theonlyengine committed: Upload 421 files
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full file list.
- .gitattributes +1 -0
- AUTHORS +1 -0
- Dockerfile +91 -0
- LICENSE +29 -0
- MANIFEST.in +11 -0
- Makefile +9 -0
- README.md +231 -0
- __init__.py +1 -0
- acc.yaml +3 -0
- acc_ignore_index.yaml +4 -0
- acctop5.yaml +4 -0
- activations.py +135 -0
- adam.yaml +2 -0
- adamw-apex-distributed.yaml +3 -0
- adamw-apex-zero.yaml +7 -0
- adamw-apex.yaml +3 -0
- adamw-zero.yaml +7 -0
- adamw.yaml +2 -0
- alibi.h +74 -0
- all_params.yaml +49 -0
- baichuan.py +151 -0
- base.yaml +82 -0
- benchmark.py +268 -0
- benchmark_alibi.py +275 -0
- benchmark_attn.py +314 -0
- benchmark_causal.py +225 -0
- benchmark_flash_attention.py +180 -0
- benchmark_flash_attention_fp8.py +333 -0
- benchmark_gemm.py +43 -0
- bert.py +764 -0
- bert_padding.py +213 -0
- bigcode.py +233 -0
- block.py +397 -0
- block_info.h +46 -0
- btlm.py +102 -0
- causality-monitor.yaml +2 -0
- comet.yaml +7 -0
- config.yaml +50 -0
- cosine-warmup-timm.yaml +2 -0
- cosine-warmup.yaml +2 -0
- cross_entropy.py +318 -0
- csv.yaml +8 -0
- cuda_bf16_fallbacks.cuh +257 -0
- cuda_bf16_wrapper.h +23 -0
- ddp.yaml +6 -0
- debug.yaml +27 -0
- decoder_masked_multihead_attention.cu +149 -0
- decoder_masked_multihead_attention.h +192 -0
- decoder_masked_multihead_attention_template.hpp +1619 -0
- decoder_masked_multihead_attention_utils.h +2017 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+flashattention_logo.png filter=lfs diff=lfs merge=lfs -text
AUTHORS
ADDED
@@ -0,0 +1 @@
Tri Dao, trid@cs.stanford.edu
Dockerfile
ADDED
@@ -0,0 +1,91 @@
# Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile
# ARG COMPAT=0
ARG PERSONAL=0
# FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
FROM nvcr.io/nvidia/pytorch:22.12-py3 as base

ENV HOST docker
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
# https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes
ENV TZ America/Los_Angeles
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# git for installing dependencies
# tzdata to set time zone
# wget and unzip to download data
# [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.
# [2021-12-07] TD: openmpi-bin for MPI (multi-node training)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    curl \
    ca-certificates \
    sudo \
    less \
    htop \
    git \
    tzdata \
    wget \
    tmux \
    zip \
    unzip \
    zsh stow subversion fasd \
    && rm -rf /var/lib/apt/lists/*
    # openmpi-bin \

# Allow running runmpi as root
# ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1

# # Create a non-root user and switch to it
# RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
#     && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
# USER user

# All users can use /home/user as their home directory
ENV HOME=/home/user
RUN mkdir -p /home/user && chmod 777 /home/user
WORKDIR /home/user

# Set up personal environment
# FROM base-${COMPAT} as env-0
FROM base as env-0
FROM env-0 as env-1
# Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image
# https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile
ONBUILD COPY dotfiles ./dotfiles
ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami)
# nvcr pytorch image sets SHELL=/bin/bash
ONBUILD ENV SHELL=/bin/zsh

FROM env-${PERSONAL} as packages

# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
ENV PIP_NO_CACHE_DIR=1

# # apex and pytorch-fast-transformers take a while to compile so we install them first
# TD [2022-04-28] apex is already installed. In case we need a newer commit:
# RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex

# xgboost conflicts with deepspeed
RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.7

# General packages that we don't care about the version
# zstandard to extract the_pile dataset
# psutil to get the number of cpu physical cores
# twine to upload package to PyPI
RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine gdown \
    && python -m spacy download en_core_web_sm
# hydra
RUN pip install hydra-core==1.3.1 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
# Core packages
RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 triton==2.0.0.dev20221202 wandb==0.13.7 timm==0.6.12 torchmetrics==0.10.3
# torchmetrics 0.11.0 broke hydra's instantiate

# For MLPerf
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

# Install FlashAttention
RUN pip install flash-attn==2.6.3

# Install CUDA extensions for fused dense
RUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib
LICENSE
ADDED
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MANIFEST.in
ADDED
@@ -0,0 +1,11 @@
recursive-include csrc *.cu
recursive-include csrc *.h
recursive-include csrc *.cuh
recursive-include csrc *.cpp
recursive-include csrc *.hpp

recursive-include flash_attn *.cu
recursive-include flash_attn *.h
recursive-include flash_attn *.cuh
recursive-include flash_attn *.cpp
recursive-include flash_attn *.hpp
Makefile
ADDED
@@ -0,0 +1,9 @@
clean_dist:
	rm -rf dist/*

create_dist: clean_dist
	python setup.py sdist

upload_package: create_dist
	twine upload dist/*
README.md
ADDED
@@ -0,0 +1,231 @@
# Optimized Transformer implementation
This repo contains examples of how FlashAttention can be integrated into a model
(e.g., GPT, ViT) and trained end-to-end. We also provide optimized
implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss,
rotary embedding). Overall this speeds up training by 3-5x compared to the
baseline implementation from Huggingface, reaching up to 189 TFLOPs/sec per A100,
equivalent to 60.6% model FLOPs utilization (we don't need any activation
checkpointing). All without changing the model architecture (i.e., no
approximation).

Goals:
- Performance: we optimize for model speed and memory, especially on 1-node
  setups (e.g., with 8 A100s).
- Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm),
  and the model code illustrates how these components can be put together.
  The training code also aims to be model- & task-agnostic.

Non-goals (and other resources):
- Support as many models as possible: Huggingface's
  [transformers](https://github.com/huggingface/transformers) and
  [timm](https://github.com/rwightman/pytorch-image-models/) are great for this.
- Large-scale distributed training: our codebase has been used for multi-GPU and multi-node
  training for models up to 2.7B parameters. However, if you're looking for large-scale distributed
  training techniques (e.g., pipeline parallelism, tensor parallelism),
  check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and
  [DeepSpeed](https://github.com/microsoft/deepspeed).
- Inference: we currently focus on training (this might change in the future).
  If you want fast inference, take a look at
  [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
- Production: this codebase was written during several research projects to validate ideas
  on speeding up ML models.

## Model Components

The GPT model is implemented
[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
Here's an example of constructing the GPT3-1.3B model with rotary embedding:
```python
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

seqlen = 2048
hidden_dim = 2048
nheads = 16
n_layer = 24
rotary_emb_fraction = 0.5
config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
                    n_layer=n_layer, n_head=nheads,
                    scale_attn_by_inverse_layer_idx=True,
                    rotary_emb_fraction=rotary_emb_fraction,
                    use_flash_attn=True, fused_mlp=True,
                    fused_bias_fc=True, fused_dropout_add_ln=True,
                    pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config)
```

We provide the following optimized components:

1. FlashAttention: fast and memory-efficient exact attention. This makes
attention much faster and saves a lot of activation memory. As a result we don't need
to use any activation checkpointing.
```sh
pip install flash-attn
```

2. Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
make it work for bfloat16. For best performance, you should use CUDA >= 11.8. cuBLAS versions before
this don't have the best matmul + bias + gelu performance for bfloat16.
```sh
cd ../csrc/fused_dense_lib && pip install .
```
3. Optimized cross-entropy loss, adapted from Apex's
[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
```sh
cd ../csrc/xentropy && pip install .
```
4. Fused rotary embedding:
```sh
cd ../csrc/rotary && pip install .
```
5. Fused dropout + residual + LayerNorm, adapted from Apex's
[FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
This supports dimensions divisible by 8, up to 6144.
```sh
cd ../csrc/layer_norm && pip install .
```

## Training

We also provide here training scripts to train GPT2 on Openwebtext and GPT3 on
The Pile as examples. Feel free to use the model in your own training setup as
well.

We use [Hydra](https://hydra.cc/) for configuration,
[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
[Wandb](https://wandb.ai/) for logging.

We use the template from `https://github.com/ashleve/lightning-hydra-template`.
Please read the instructions there to understand the repo structure.

### Requirements

Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
We recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).

We provide a Dockerfile that lists all the required packages.

### Dataset preparation

Running the training command will automatically download the datasets
(Openwebtext, Pile), tokenize with the GPT2 tokenizer, concatenate all the
tokens, then save this cache to disk. Alternatively, you can also prepare the
datasets as a separate step.

The cached datasets are saved to `${DATA_DIR}/openwebtext` and
`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
`./data/{openwebtext,the_pile}`.

- Openwebtext:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "openwebtext"
```
This takes around 1h on a 64-core CPU. The processed dataset has size 17GB.

- The Pile:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
```
This takes around 20h on a 64-core CPU. The processed dataset has size 699GB.

### GPT2 training on Openwebtext
To train GPT2 on Openwebtext with 8 GPUs:
```sh
python run.py experiment=owt/gpt2s-flash trainer.devices=8  # 125M
python run.py experiment=owt/gpt2m-flash trainer.devices=8  # 355M
python run.py experiment=owt/gpt2l-flash trainer.devices=8  # 760M
python run.py experiment=owt/gpt2xl-flash trainer.devices=8  # 1.6B
```
The default parameters are set for 8 x A100 80GB.

To train with bf16 instead of fp16, add `trainer.precision=bf16`.

### GPT3 training on The Pile
To train GPT3 on The Pile with 8 GPUs:
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8  # 125M
python run.py experiment=pile/gpt3m-flash trainer.devices=8  # 355M
python run.py experiment=pile/gpt3l-flash trainer.devices=8  # 760M
python run.py experiment=pile/gpt3xl-flash trainer.devices=8  # 1.3B
python run.py experiment=pile/gpt3-2.7B-flash-hdim128 trainer.devices=8  # 2.7B
```
The default parameters are set for 8 x A100 80GB. We train with bf16 by default.

To train with rotary embedding, run the experiments `pile/gpt3{s,m,l,xl}-flash-rotary`.

### Training options

**Gradient accumulation**: to adjust the per-device batch size to fit into GPU memory
(the global batch size stays the same, and gradient accumulation is calculated
automatically), set `datamodule.batch_size=blah`.

**Multi-node**: to train on multiple nodes, add `trainer.num_nodes=blah`.

**Speed benchmarking**: to print out iteration time, add `+callbacks.speed_monitor.verbose=True`.

**Resumable training**: give the run a name, then set `resume=True` when
you relaunch. Training will restart at exactly the same batch.
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8 name=pile-gpt3s-flash resume=True
```

## Training speed

We measure the wallclock training speed on one node with 8 x A100-SXM4-80GB (400W) GPUs with NVLink.

FLOPs are calculated using the formula from the [Megatron-LM
paper](https://arxiv.org/abs/2104.04473) (Section 5.1), except we scale by 3/4
to get the model FLOPs (instead of hardware FLOPs with activation
checkpointing).


### GPT2 (sequence length 1024)

![GPT2 speedup](../assets/gpt2_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-4x faster than the
baseline implementation from Huggingface.

### GPT3 (sequence length 2048)

![GPT3 speedup](../assets/gpt3_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-5x faster than the
baseline implementation from Huggingface.

For the GPT3-2.7B model, we set the head dimension to 128 (instead of 80) for better efficiency.

We include here more details on the training speed with FlashAttention on 8 x
A100 80GB.

| Model     | Batch size (tokens) | Throughput (tokens/sec) | Hours / 1B tokens |
| --------- | ------------------- | ----------------------- | ----------------- |
| GPT3-125M | 0.5M                | 1310k                   | 0.21              |
| GPT3-355M | 0.5M                | 503k                    | 0.55              |
| GPT3-760M | 0.5M                | 245k                    | 1.13              |
| GPT3-1.3B | 1M                  | 169k                    | 1.64              |
| GPT3-2.7B | 1M                  | 85k                     | 3.27              |

As an example, this means that one can train a GPT3-1.3B model on 26B tokens
(compute-optimal according to Chinchilla scaling) in about 43 hours on 8 x A100.

## Training quality

We include here the loss curve for GPT2 on Openwebtext, trained for 200B tokens.
For GPT2, the runs with FlashAttention yield the same loss curve as the runs
with the baseline implementation from Huggingface for the 125M and 355M models. For
larger models, the baseline implementation simply takes too long to run.

![GPT2 training curve](../assets/gpt2_training_curve.jpg)

We include here the loss curve for GPT3 on The Pile, trained for 400B tokens.
The 125M, 355M, 760M models have batch size 512k tokens, so this translates to
800k training steps, while the 1.3B and 2.7B models have batch size 1M tokens,
which translates to 400k training steps.

![GPT3 training curve](../assets/gpt3_training_curve.jpg)
__init__.py
ADDED
@@ -0,0 +1 @@
__version__ = "3.0.0.b1"
acc.yaml
ADDED
@@ -0,0 +1,3 @@
# @package eval.metrics
acc:
  _target_: src.metrics.accuracy.AccuracyMine
acc_ignore_index.yaml
ADDED
@@ -0,0 +1,4 @@
# @package eval.metrics
acc:
  _target_: torchmetrics.Accuracy
  ignore_index: -100
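For reference, `ignore_index: -100` means positions labeled -100 (the usual ignored-label convention for LM targets) are excluded from the accuracy count. A plain-PyTorch sketch of that semantics (illustrative, not the torchmetrics implementation):
```python
import torch

def accuracy_ignore_index(preds, target, ignore_index=-100):
    # Drop positions whose label equals ignore_index, then compare the rest.
    mask = target != ignore_index
    return (preds[mask] == target[mask]).float().mean()

preds = torch.tensor([1, 2, 3, 0])
target = torch.tensor([1, 2, -100, 1])
print(accuracy_ignore_index(preds, target))  # tensor(0.6667): 2 correct of 3 counted
```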
acctop5.yaml
ADDED
@@ -0,0 +1,4 @@
# @package eval.metrics
acctop5:
  _target_: src.metrics.accuracy.AccuracyMine
  top_k: 5
activations.py
ADDED
@@ -0,0 +1,135 @@
# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456

# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, input, bias)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply

# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def gelu_fwd(x):
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def gelu_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return (ff * g).to(dtype=x.dtype)


class FastGeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply


@torch.jit.script
def relu_bwd(g, x):
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)


class SwiGLUFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        return swiglu_bwd(x, y, dout)

swiglu = SwiGLUFunction.apply
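A quick way to sanity-check the scripted GELU above is against PyTorch's own tanh-approximate GELU (requires PyTorch >= 1.12 for `approximate="tanh"`); the SwiGLU path uses CUDA jiterator kernels, so it needs CUDA tensors. A sketch, assuming the definitions above are in scope:
```python
import torch
import torch.nn.functional as F

# gelu_fwd is a TorchScript function, so it also runs on CPU.
x = torch.randn(4, 8)
print(torch.allclose(gelu_fwd(x), F.gelu(x, approximate="tanh"), atol=1e-6))  # True

if torch.cuda.is_available():
    a = torch.randn(4, 8, device="cuda", requires_grad=True)
    b = torch.randn(4, 8, device="cuda", requires_grad=True)
    swiglu(a, b).sum().backward()  # exercises the fused CUDA forward and backward
```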
adam.yaml
ADDED
@@ -0,0 +1,2 @@
# @package train.optimizer
_target_: torch.optim.Adam
adamw-apex-distributed.yaml
ADDED
@@ -0,0 +1,3 @@
# @package train.optimizer
_target_: apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam
adam_w_mode: True
adamw-apex-zero.yaml
ADDED
@@ -0,0 +1,7 @@
# @package train.optimizer
_target_: torch.distributed.optim.ZeroRedundancyOptimizer
_recursive_: True
optimizer_class:
  _target_: apex.optimizers.FusedAdam
  _partial_: True
  adam_w_mode: True
adamw-apex.yaml
ADDED
@@ -0,0 +1,3 @@
# @package train.optimizer
_target_: apex.optimizers.FusedAdam
adam_w_mode: True
adamw-zero.yaml
ADDED
@@ -0,0 +1,7 @@
# @package train.optimizer
_target_: torch.distributed.optim.ZeroRedundancyOptimizer
_recursive_: True
optimizer_class:
  _target_: torch.optim.__getattribute__
  _args_:
    - "AdamW"
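The `torch.optim.__getattribute__` target with `_args_: ["AdamW"]` is a way to hand Hydra the optimizer class itself, since ZeroRedundancyOptimizer expects a class (not an instance) for `optimizer_class`. A small check of what it resolves to:
```python
import torch

# Hydra calls the target with its _args_, which here just looks up the class.
optimizer_class = torch.optim.__getattribute__("AdamW")
print(optimizer_class is torch.optim.AdamW)  # True
```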
adamw.yaml
ADDED
@@ -0,0 +1,2 @@
# @package train.optimizer
_target_: torch.optim.AdamW
alibi.h
ADDED
@@ -0,0 +1,74 @@
#include <cmath>

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>

#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_causal>
struct Alibi {

    const float alibi_slope;
    const int max_seqlen_k, max_seqlen_q;

    __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
        : alibi_slope(alibi_slope)
        , max_seqlen_k(max_seqlen_k)
        , max_seqlen_q(max_seqlen_q) {
    };


    template <typename Engine, typename Layout>
    __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                                const int col_idx_offset_,
                                                const int row_idx_offset,
                                                const int warp_row_stride) {
        // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
        static_assert(Layout::rank == 2, "Only support 2D Tensor");
        const int lane_id = threadIdx.x % 32;
        const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
        if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    #pragma unroll
                    for (int mi = 0; mi < size<0>(tensor); ++mi) {
                        tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                    }
                }
            }
        } else {  // Bias depends on both row_idx and col_idx
            #pragma unroll
            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                #pragma unroll
                for (int i = 0; i < size<0, 0>(tensor); ++i) {
                    const int row_idx = row_idx_base + i * 8;
                    #pragma unroll
                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                        const int col_idx_base = col_idx_offset + nj * 8;
                        #pragma unroll
                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
                            const int col_idx = col_idx_base + j;
                            tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                        }
                    }
                }
            }
        }
    }

};

}  // namespace flash
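For reference, here is a PyTorch sketch of the per-head bias this kernel applies to the attention scores (the `alibi_bias` helper is illustrative, mirroring `attn_bias_from_alibi_slopes` in benchmark_alibi.py). In the causal branch, adding `alibi_slope * col_idx` is equivalent to the usual `-slope * (row - col)` distance bias because softmax is invariant to a per-row shift:
```python
import torch

def alibi_bias(slopes, seqlen_q, seqlen_k, causal=False):
    """slopes: (nheads,) -> bias: (nheads, seqlen_k) if causal, else (nheads, seqlen_q, seqlen_k)."""
    if causal:
        # Same per-row shift as the kernel's `alibi_slope * col_idx` branch.
        return slopes.view(-1, 1) * torch.arange(-seqlen_k + 1, 1, dtype=torch.float32)
    row = torch.arange(seqlen_q).view(-1, 1)
    col = torch.arange(seqlen_k).view(1, -1)
    rel = (row + seqlen_k - seqlen_q - col).abs()
    return -slopes.view(-1, 1, 1) * rel.float()

slopes = torch.tensor([1 / 2**i for i in range(1, 5)])  # 4 heads, geometric slopes
print(alibi_bias(slopes, seqlen_q=8, seqlen_k=8).shape)  # torch.Size([4, 8, 8])
```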
all_params.yaml
ADDED
@@ -0,0 +1,49 @@
_target_: pytorch_lightning.Trainer

# default values for all trainer parameters
checkpoint_callback: True
default_root_dir: null
gradient_clip_val: 0.0
process_position: 0
num_nodes: 1
num_processes: 1
gpus: null
auto_select_gpus: False
tpu_cores: null
log_gpu_memory: null
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: False
accumulate_grad_batches: 1
max_epochs: 1
min_epochs: 1
max_steps: null
min_steps: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
val_check_interval: 1.0
flush_logs_every_n_steps: 100
log_every_n_steps: 50
accelerator: null
sync_batchnorm: False
precision: 32
weights_summary: "top"
weights_save_path: null
num_sanity_val_steps: 2
truncated_bptt_steps: null
resume_from_checkpoint: null
profiler: null
benchmark: False
deterministic: False
reload_dataloaders_every_epoch: False
auto_lr_find: False
replace_sampler_ddp: True
terminate_on_nan: False
auto_scale_batch_size: False
prepare_data_per_node: True
plugins: null
amp_backend: "native"
amp_level: "O2"
move_metrics_to_cpu: False
baichuan.py
ADDED
@@ -0,0 +1,151 @@
# Copyright (c) 2023, GGGGGGXY, Tri Dao.

import math
import json
import re
from pathlib import Path

from collections import OrderedDict

import torch
import torch.nn.functional as F

from einops import rearrange
from transformers import GPT2Config, AutoConfig, PretrainedConfig


def remap_state_dict_hf_baichuan(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^model.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.embed_tokens.",
            "transformer.embeddings.word_embeddings.",
            key,
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
        * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict[
            "transformer.embeddings.word_embeddings.weight"
        ]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since Baichuan shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.mlp.up_proj.weight")
        # Our ordering is different
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat(
            [w3, w1], dim=0
        )

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).mlp.down_proj.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn.W_pack.",
            r"transformer.layers.\1.mixer.Wqkv.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn.o_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    for l in range(config.n_layer):
        # pop rotary_emb.inv_freq from state dict
        state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None)
    return state_dict


def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
    # HACK: the config doesn't say whether it's rotary or alibi.
    # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
    # HACK: the config doesn't say whether it uses norm head.
    # So we have to infer from the vocab size
    # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).
    use_rotary = baichuan_config.hidden_size < 5000
    return GPT2Config(
        vocab_size=baichuan_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=baichuan_config.hidden_size,
        n_layer=baichuan_config.num_hidden_layers,
        n_head=baichuan_config.num_attention_heads,
        n_inner=baichuan_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # baichuan doesn't have dropout, idk if it's because they only release the inference code
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=baichuan_config.rms_norm_eps,
        initializer_range=baichuan_config.initializer_range,
        bos_token_id=baichuan_config.bos_token_id,
        eos_token_id=baichuan_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=baichuan_config.pad_token_id,  # Idk if this does anything
        rms_norm=True,
        rotary_emb_fraction=1.0 if use_rotary else 0.0,
        rotary_emb_interleaved=False,
        use_alibi=not use_rotary,
        use_flash_attn=not use_rotary,  # Alibi code path requires flash_attn
        tie_word_embeddings=False,
        norm_head=baichuan_config.vocab_size > 70000,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
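A hedged sketch of how these two helpers could be used together to load a HF Baichuan checkpoint into this repo's GPT model; the checkpoint id, the `GPTLMHeadModel` import path (taken from the README example), and the constructor keyword arguments are assumptions rather than a documented API:
```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel  # path assumed from the README

name = "baichuan-inc/Baichuan-7B"  # hypothetical checkpoint id
hf_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
config = baichuan_config_to_gpt2_config(hf_config)

# Remap the HF parameter names into this repo's GPT layout, then load.
hf_model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)
state_dict = remap_state_dict_hf_baichuan(hf_model.state_dict(), config)

model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
model.load_state_dict(state_dict)
```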
base.yaml
ADDED
@@ -0,0 +1,82 @@
# @package _global_
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: openwebtext
  # FusedAdam from apex speeds up the optimizer step a bit, for GPT2-small time
  # per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms.
  # For GPT2-medium time per global step goes from 997ms to 972ms.
  - override /optimizer: adamw-apex
  - override /scheduler: linear-warmup
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

task:
  _target_: src.tasks.seq.SequenceLMModel

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
  max_steps: 400000
  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
  check_val_every_n_epoch: null # We don't care about epoch boundary
  precision: 16
  gradient_clip_val: 1.0
  strategy: null

datamodule:
  batch_size: 16 # Per GPU
  batch_size_eval: ${.batch_size} # Fused dense only supports batch size at most 64k
  max_length: 1024
  fault_tolerant: True
  ddp: ${eval:"${trainer.devices} > 1"}

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  global_batch_size: 512
  optimizer:
    lr: 6e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
  scheduler:
    num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}
    num_training_steps: ${trainer.max_steps}
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable if we're using DeepSpeed 16 bits.
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: True # to save memory

eval:
  log_on_step: True # 1 training epoch takes too long, we want to see metrics per train step

callbacks:
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: True
    every_n_train_steps: 1000
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
    filename: step_{step}
    auto_insert_metric_name: False
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    fault_tolerant: True
    every_n_train_steps: 50000
    save_last: False
    save_top_k: -1 # Save all the checkpoints
    dirpath: ${..model_checkpoint.dirpath}
    filename: progress_step_{step}
    auto_insert_metric_name: False
  early_stopping: null
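The `accumulate_grad_batches` interpolation above is just ceil-division of the global batch size by the number of sequences processed per micro-step across all GPUs. A sketch of the arithmetic, assuming `div_up` is a ceil-division resolver (its exact implementation is not shown in this upload):
```python
def div_up(a, b):
    # Ceil division, as the custom Hydra resolver is assumed to compute.
    return (a + b - 1) // b

global_batch_size = 512
devices, num_nodes, per_gpu_batch = 8, 1, 16
print(div_up(global_batch_size, devices * per_gpu_batch * num_nodes))  # 4 accumulation steps
```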
benchmark.py
ADDED
@@ -0,0 +1,268 @@
# Copyright (c) 2023, Tri Dao.
""" Useful functions for writing test code. """

import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid extra operation of gradient accumulation
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in Pytorch profiler to see CUDA information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
            g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)


def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
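Example usage of the helpers above (a sketch, assuming a CUDA device and the functions above in scope): each helper returns a `(Timer, Measurement)` pair, so `measurement.mean` is the average seconds per call.
```python
import torch

a = torch.randn(4096, 4096, device="cuda", requires_grad=True)
b = torch.randn(4096, 4096, device="cuda", requires_grad=True)

# Forward-only timing, then forward and backward timed separately.
_, m_fwd = benchmark_forward(torch.matmul, a, b, repeats=10, desc="matmul", verbose=False)
(_, m_fwd2), (_, m_bwd) = benchmark_fwd_bwd(torch.matmul, a, b, repeats=10, verbose=False)
print(f"fwd {m_fwd.mean * 1e3:.2f} ms, bwd {m_bwd.mean * 1e3:.2f} ms")
```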
benchmark_alibi.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Sanghun Cho, Tri Dao.
|
2 |
+
|
3 |
+
import pickle
|
4 |
+
import math
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from flash_attn.layers.rotary import apply_rotary_emb
|
11 |
+
|
12 |
+
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
|
13 |
+
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
|
14 |
+
|
15 |
+
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
|
16 |
+
|
17 |
+
try:
|
18 |
+
import xformers.ops as xops
|
19 |
+
except ImportError:
|
20 |
+
xops = None
|
21 |
+
|
22 |
+
|
23 |
+
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
|
24 |
+
assert rotary_dim % 2 == 0
|
25 |
+
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
|
26 |
+
cos = torch.cos(angle).to(dtype=dtype)
|
27 |
+
sin = torch.sin(angle).to(dtype=dtype)
|
28 |
+
return cos, sin
|
29 |
+
|
30 |
+
|
31 |
+
def flash_rotary(q, k, v, cos, sin, causal=False):
|
32 |
+
# corrected by @tridao comments
|
33 |
+
q = apply_rotary_emb(
|
34 |
+
q, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
|
35 |
+
)
|
36 |
+
k = apply_rotary_emb(
|
37 |
+
k, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
|
38 |
+
)
|
39 |
+
|
40 |
+
return flash_attn_func(q, k, v, causal=causal)
|
41 |
+
|
42 |
+
|
43 |
+
def attn_bias_from_alibi_slopes(
|
44 |
+
slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
|
45 |
+
):
|
46 |
+
batch, nheads = slopes.shape
|
47 |
+
device = slopes.device
|
48 |
+
slopes = rearrange(slopes, "b h -> b h 1 1")
|
49 |
+
if causal:
|
50 |
+
return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
|
51 |
+
else:
|
52 |
+
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
|
53 |
+
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
|
54 |
+
sk = (
|
55 |
+
seqlen_k
|
56 |
+
if key_padding_mask is None
|
57 |
+
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
|
58 |
+
)
|
59 |
+
sq = (
|
60 |
+
seqlen_q
|
61 |
+
if query_padding_mask is None
|
62 |
+
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
|
63 |
+
)
|
64 |
+
relative_pos = torch.abs(row_idx + sk - sq - col_idx)
|
65 |
+
return -slopes * relative_pos.to(dtype=slopes.dtype)
|
66 |
+
|
67 |
+
|
68 |
+
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
|
69 |
+
assert mode in ["fwd", "bwd", "fwd_bwd"]
|
70 |
+
f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
|
71 |
+
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
|
72 |
+
|
73 |
+
|
74 |
+
def efficiency(flop, time):
|
75 |
+
return (flop / time / 10**12) if not math.isnan(time) else 0.0
|
76 |
+
|
77 |
+
|
78 |
+
def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
|
79 |
+
"""
|
80 |
+
Arguments:
|
81 |
+
q, k, v: (batch_size, seqlen, nheads, head_dim)
|
82 |
+
dropout_p: float
|
83 |
+
attn_bias: (batch_size, nheads, seqlen, seqlen) or (1, nheads, seqlen, seqlen)
|
84 |
+
Output:
|
85 |
+
output: (batch_size, seqlen, nheads, head_dim)
|
86 |
+
"""
|
87 |
+
batch_size, seqlen, nheads, d = q.shape
|
88 |
+
q = rearrange(q, 'b t h d -> (b h) t d')
|
89 |
+
k = rearrange(k, 'b s h d -> (b h) d s')
|
90 |
+
softmax_scale = 1.0 / math.sqrt(d)
|
91 |
+
# Preallocate attn_weights for `baddbmm`
|
92 |
+
if attn_bias is not None:
|
93 |
+
scores = rearrange(attn_bias, 'b h t s -> (b h) t s')
|
94 |
+
else:
|
95 |
+
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)
|
96 |
+
scores = rearrange(torch.baddbmm(scores, q, k, beta=1.0 if attn_bias is not None else 0.0, alpha=softmax_scale),
|
97 |
+
'(b h) t s -> b h t s', h=nheads)
|
98 |
+
if causal:
|
99 |
+
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
100 |
+
# So we have to construct the mask in float
|
101 |
+
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
102 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
103 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
104 |
+
attention = torch.softmax(scores, dim=-1)
|
105 |
+
attention_drop = F.dropout(attention, dropout_p)
|
106 |
+
output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
|
107 |
+
return output.to(dtype=q.dtype)
|
108 |
+
|
109 |
+
|
110 |
+
def time_fwd_bwd(func, *args, **kwargs):
|
111 |
+
time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
|
112 |
+
return time_f[1].mean, time_b[1].mean
|
113 |
+
|
114 |
+
|
115 |
+
repeats = 30
|
116 |
+
device = 'cuda'
|
117 |
+
dtype = torch.float16
|
118 |
+
|
119 |
+
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
|
120 |
+
causal_vals = [False, True]
|
121 |
+
headdim_vals = [64, 128]
|
122 |
+
dim = 2048
|
123 |
+
dropout_p = 0.0
|
124 |
+
|
125 |
+
methods = (["fa2_alibi", "torch"]
|
126 |
+
+ (["xformers"] if xops is not None else [])
|
127 |
+
+ ["sdpa"]
|
128 |
+
+ ["fa2_baseline"]
|
129 |
+
+ ["fa2_rotary"])
|
130 |
+
|
131 |
+
time_f = {}
|
132 |
+
time_b = {}
|
133 |
+
time_f_b = {}
|
134 |
+
speed_f = {}
|
135 |
+
speed_b = {}
|
136 |
+
speed_f_b = {}
|
137 |
+
for causal in causal_vals:
|
138 |
+
for headdim in headdim_vals:
|
139 |
+
for batch_size, seqlen in bs_seqlen_vals:
|
140 |
+
config = (causal, headdim, batch_size, seqlen)
|
141 |
+
nheads = dim // headdim
|
142 |
+
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
|
143 |
+
requires_grad=True) for _ in range(3)]
|
144 |
+
# alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
|
145 |
+
alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3
|
146 |
+
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)
|
147 |
+
attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size)
|
148 |
+
f, b = time_fwd_bwd(
|
149 |
+
flash_attn_func,
|
150 |
+
q, k, v,
|
151 |
+
dropout_p,
|
152 |
+
causal=causal,
|
153 |
+
# alibi_slopes=alibi_slopes,
|
154 |
+
alibi_slopes=None,
|
155 |
+
repeats=repeats,
|
156 |
+
verbose=False
|
157 |
+
)
|
158 |
+
time_f[config, "fa2_baseline"] = f
|
159 |
+
time_b[config, "fa2_baseline"] = b
|
160 |
+
|
161 |
+
q = q.detach().requires_grad_(True)
|
162 |
+
k = k.detach().requires_grad_(True)
|
163 |
+
v = v.detach().requires_grad_(True)
|
164 |
+
f, b = time_fwd_bwd(
|
165 |
+
flash_attn_func,
|
166 |
+
q, k, v,
|
167 |
+
dropout_p,
|
168 |
+
causal=causal,
|
169 |
+
alibi_slopes=rearrange(alibi_slopes, "1 h -> h"),
|
170 |
+
# alibi_slopes=None,
|
171 |
+
repeats=repeats,
|
172 |
+
verbose=False
|
173 |
+
)
|
174 |
+
time_f[config, "fa2_alibi"] = f
|
175 |
+
time_b[config, "fa2_alibi"] = b
|
176 |
+
|
177 |
+
try:
|
178 |
+
q = q.detach().requires_grad_(True)
|
179 |
+
k = k.detach().requires_grad_(True)
|
180 |
+
v = v.detach().requires_grad_(True)
|
181 |
+
f, b = time_fwd_bwd(
|
182 |
+
attention_pytorch,
|
183 |
+
q, k, v,
|
184 |
+
dropout_p,
|
185 |
+
causal=causal,
|
186 |
+
attn_bias=attn_bias,
|
187 |
+
repeats=repeats,
|
188 |
+
verbose=False
|
189 |
+
)
|
190 |
+
except: # Skip if OOM
|
191 |
+
f, b = float('nan'), float('nan')
|
192 |
+
time_f[config, "torch"] = f
|
193 |
+
time_b[config, "torch"] = b
|
194 |
+
|
195 |
+
# F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe
|
196 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=False):
|
197 |
+
q_pt = q.detach().requires_grad_(True).transpose(1, 2)
|
198 |
+
k_pt = k.detach().requires_grad_(True).transpose(1, 2)
|
199 |
+
v_pt = v.detach().requires_grad_(True).transpose(1, 2)
|
200 |
+
f, b = time_fwd_bwd(
|
201 |
+
F.scaled_dot_product_attention,
|
202 |
+
q_pt, k_pt, v_pt,
|
203 |
+
attn_mask=attn_bias,
|
204 |
+
dropout_p=dropout_p,
|
205 |
+
is_causal=causal,
|
206 |
+
repeats=repeats,
|
207 |
+
verbose=False
|
208 |
+
)
|
209 |
+
time_f[config, "sdpa"] = f
|
210 |
+
time_b[config, "sdpa"] = b
|
211 |
+
|
212 |
+
if xops is not None:
|
213 |
+
q = q.detach().requires_grad_(True)
|
214 |
+
k = k.detach().requires_grad_(True)
|
215 |
+
v = v.detach().requires_grad_(True)
|
216 |
+
if causal:
|
217 |
+
attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))
|
218 |
+
# NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
|
219 |
+
# `flshattB@v2.3.6` is not supported because:
|
220 |
+
# attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
|
221 |
+
# `cutlassB` is not supported because:
|
222 |
+
# attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
|
223 |
+
attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)
|
224 |
+
else:
|
225 |
+
attn_bias_xops = attn_bias.to(dtype=q.dtype)
|
226 |
+
f, b = time_fwd_bwd(
|
227 |
+
xops.memory_efficient_attention,
|
228 |
+
q, k, v,
|
229 |
+
attn_bias_xops,
|
230 |
+
dropout_p,
|
231 |
+
repeats=repeats,
|
232 |
+
verbose=False
|
233 |
+
)
|
234 |
+
time_f[config, "xformers"] = f
|
235 |
+
time_b[config, "xformers"] = b
|
236 |
+
|
237 |
+
q = q.detach().requires_grad_(True)
|
238 |
+
k = k.detach().requires_grad_(True)
|
239 |
+
v = v.detach().requires_grad_(True)
|
240 |
+
cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
|
241 |
+
f, b = time_fwd_bwd(
|
242 |
+
flash_rotary,
|
243 |
+
q, k, v,
|
244 |
+
cos, sin,
|
245 |
+
causal,
|
246 |
+
repeats=repeats,
|
247 |
+
verbose=False
|
248 |
+
)
|
249 |
+
time_f[config, "fa2_rotary"] = f
|
250 |
+
time_b[config, "fa2_rotary"] = b
|
251 |
+
|
252 |
+
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
|
253 |
+
csv_output = ""
|
254 |
+
csv_output += f"{causal},{headdim},{batch_size},{seqlen},"
|
255 |
+
for method in methods:
|
256 |
+
time_f_b[config, method] = time_f[config, method] + time_b[config, method]
|
257 |
+
speed_f[config, method] = efficiency(
|
258 |
+
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
|
259 |
+
time_f[config, method]
|
260 |
+
)
|
261 |
+
speed_b[config, method] = efficiency(
|
262 |
+
flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
|
263 |
+
time_b[config, method]
|
264 |
+
)
|
265 |
+
speed_f_b[config, method] = efficiency(
|
266 |
+
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
|
267 |
+
time_f_b[config, method]
|
268 |
+
)
|
269 |
+
print(
|
270 |
+
f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
|
271 |
+
f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
|
272 |
+
f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
|
273 |
+
)
|
274 |
+
csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},"
|
275 |
+
print(csv_output)
|
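A small correctness sketch to go with the timing loop above (not part of the diff). It assumes it runs in the same module as attn_bias_from_alibi_slopes and attention_pytorch, with flash-attn installed and a CUDA GPU; the shapes and the fp16 tolerance expectation are illustrative only.

import torch
from einops import rearrange, repeat
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 4, 64
q, k, v = [torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
           for _ in range(3)]
slopes = torch.rand(1, nheads, device="cuda", dtype=torch.float32) * 0.3

# Same construction as the benchmark: materialize the additive ALiBi bias for the
# PyTorch reference, and hand the raw per-head slopes to FlashAttention.
bias = attn_bias_from_alibi_slopes(slopes, seqlen, seqlen, causal=False).to(q.dtype)
bias = repeat(bias, "1 ... -> b ...", b=batch)

out_ref = attention_pytorch(q, k, v, 0.0, causal=False, attn_bias=bias)
out_fa2 = flash_attn_func(q, k, v, 0.0, causal=False, alibi_slopes=rearrange(slopes, "1 h -> h"))
print((out_ref - out_fa2).abs().max())  # expected to be small (fp16 rounding only)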
benchmark_attn.py
ADDED
@@ -0,0 +1,314 @@
1 |
+
from functools import partial
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import time
|
8 |
+
|
9 |
+
try:
|
10 |
+
import cudnn
|
11 |
+
except ImportError:
|
12 |
+
cudnn = None
|
13 |
+
|
14 |
+
|
15 |
+
from einops import rearrange, repeat
|
16 |
+
|
17 |
+
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
|
18 |
+
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
|
19 |
+
from flash_attn.flash_attn_interface import flash_attn_func
|
20 |
+
from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3
|
21 |
+
|
22 |
+
# Need to install triton nightly:
|
23 |
+
# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
|
24 |
+
|
25 |
+
try:
|
26 |
+
from triton_fused_attention import attention as triton_attention
|
27 |
+
except ImportError:
|
28 |
+
triton_attention = None
|
29 |
+
|
30 |
+
def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'):
|
31 |
+
assert mode in ["fwd", "bwd", "fwd_bwd"]
|
32 |
+
f = 4 * batch * seqlen_q * seqlen_k * nheads * headdim // (2 if causal else 1)
|
33 |
+
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
|
34 |
+
|
35 |
+
|
36 |
+
def convert_to_cudnn_type(torch_type):
|
37 |
+
if torch_type == torch.float16:
|
38 |
+
return cudnn.data_type.HALF
|
39 |
+
elif torch_type == torch.bfloat16:
|
40 |
+
return cudnn.data_type.BFLOAT16
|
41 |
+
elif torch_type == torch.float32:
|
42 |
+
return cudnn.data_type.FLOAT
|
43 |
+
elif torch_type == torch.int32:
|
44 |
+
return cudnn.data_type.INT32
|
45 |
+
elif torch_type == torch.int64:
|
46 |
+
return cudnn.data_type.INT64
|
47 |
+
else:
|
48 |
+
raise ValueError("Unsupported tensor data type.")
|
49 |
+
|
50 |
+
|
51 |
+
def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
|
52 |
+
b, nheads, seqlen_q, headdim = q.shape
|
53 |
+
_, nheads_kv, seqlen_k, _ = k.shape
|
54 |
+
assert v.shape == (b, nheads_kv, seqlen_k, headdim)
|
55 |
+
assert cudnn is not None, 'CUDNN is not available'
|
56 |
+
q_gpu, k_gpu, v_gpu = q, k, v
|
57 |
+
o_gpu, stats_gpu = o, stats
|
58 |
+
graph_forward = cudnn.pygraph(
|
59 |
+
io_data_type=convert_to_cudnn_type(q.dtype),
|
60 |
+
intermediate_data_type=cudnn.data_type.FLOAT,
|
61 |
+
compute_data_type=cudnn.data_type.FLOAT,
|
62 |
+
)
|
63 |
+
q_forward = graph_forward.tensor_like(q_gpu.detach())
|
64 |
+
k_forward = graph_forward.tensor_like(k_gpu.detach())
|
65 |
+
v_forward = graph_forward.tensor_like(v_gpu.detach())
|
66 |
+
|
67 |
+
seqlens_reshaped = seqlens if varlen else None
|
68 |
+
seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
69 |
+
seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
70 |
+
|
71 |
+
o_forward, stats_forward = graph_forward.sdpa(
|
72 |
+
name="sdpa",
|
73 |
+
q=q_forward,
|
74 |
+
k=k_forward,
|
75 |
+
v=v_forward,
|
76 |
+
is_inference=False,
|
77 |
+
attn_scale=1.0 / math.sqrt(headdim),
|
78 |
+
use_causal_mask=causal,
|
79 |
+
use_padding_mask=varlen,
|
80 |
+
seq_len_q=seq_len_q,
|
81 |
+
seq_len_kv=seq_len_kv,
|
82 |
+
)
|
83 |
+
|
84 |
+
o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
|
85 |
+
stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT)
|
86 |
+
|
87 |
+
graph_forward.validate()
|
88 |
+
graph_forward.build_operation_graph()
|
89 |
+
graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
|
90 |
+
graph_forward.check_support()
|
91 |
+
graph_forward.build_plans()
|
92 |
+
|
93 |
+
variant_pack_forward = {
|
94 |
+
q_forward: q_gpu,
|
95 |
+
k_forward: k_gpu,
|
96 |
+
v_forward: v_gpu,
|
97 |
+
o_forward: o_gpu,
|
98 |
+
stats_forward: stats_gpu,
|
99 |
+
seq_len_q: seqlens_reshaped,
|
100 |
+
seq_len_kv: seqlens_reshaped,
|
101 |
+
}
|
102 |
+
|
103 |
+
dQ_gpu = torch.empty_like(q_gpu)
|
104 |
+
dK_gpu = torch.empty_like(k_gpu)
|
105 |
+
dV_gpu = torch.empty_like(v_gpu)
|
106 |
+
dO_gpu = grad
|
107 |
+
|
108 |
+
graph_backward = cudnn.pygraph(
|
109 |
+
io_data_type=cudnn.data_type.HALF,
|
110 |
+
intermediate_data_type=cudnn.data_type.FLOAT,
|
111 |
+
compute_data_type=cudnn.data_type.FLOAT,
|
112 |
+
)
|
113 |
+
|
114 |
+
q_backward = graph_backward.tensor_like(q_gpu.detach())
|
115 |
+
k_backward = graph_backward.tensor_like(k_gpu.detach())
|
116 |
+
v_backward = graph_backward.tensor_like(v_gpu.detach())
|
117 |
+
o_backward = graph_backward.tensor_like(o_gpu.detach())
|
118 |
+
dO_backward = graph_backward.tensor_like(dO_gpu.detach())
|
119 |
+
stats_backward = graph_backward.tensor_like(stats_gpu.detach())
|
120 |
+
seq_len_q = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
121 |
+
seq_len_kv = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
122 |
+
|
123 |
+
dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
|
124 |
+
name="sdpa_backward",
|
125 |
+
q=q_backward,
|
126 |
+
k=k_backward,
|
127 |
+
v=v_backward,
|
128 |
+
o=o_backward,
|
129 |
+
dO=dO_backward,
|
130 |
+
stats=stats_backward,
|
131 |
+
attn_scale=1.0 / math.sqrt(headdim),
|
132 |
+
use_causal_mask=causal,
|
133 |
+
use_padding_mask=varlen,
|
134 |
+
seq_len_q=seq_len_q,
|
135 |
+
seq_len_kv=seq_len_kv,
|
136 |
+
)
|
137 |
+
|
138 |
+
dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
|
139 |
+
dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
|
140 |
+
dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())
|
141 |
+
|
142 |
+
graph_backward.validate()
|
143 |
+
graph_backward.build_operation_graph()
|
144 |
+
graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
|
145 |
+
graph_backward.check_support()
|
146 |
+
graph_backward.build_plans()
|
147 |
+
|
148 |
+
variant_pack_backward = {
|
149 |
+
q_backward: q_gpu,
|
150 |
+
k_backward: k_gpu,
|
151 |
+
v_backward: v_gpu,
|
152 |
+
o_backward: o_gpu,
|
153 |
+
dO_backward: dO_gpu,
|
154 |
+
stats_backward: stats_gpu,
|
155 |
+
dQ_backward: dQ_gpu,
|
156 |
+
dK_backward: dK_gpu,
|
157 |
+
dV_backward: dV_gpu,
|
158 |
+
seq_len_q: seqlens_reshaped,
|
159 |
+
seq_len_kv: seqlens_reshaped,
|
160 |
+
}
|
161 |
+
|
162 |
+
workspace = torch.empty(
|
163 |
+
max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()),
|
164 |
+
device="cuda", dtype=torch.uint8
|
165 |
+
)
|
166 |
+
|
167 |
+
def run_fwd(*args, **kwargs):
|
168 |
+
graph_forward.execute(variant_pack_forward, workspace)
|
169 |
+
return o_gpu, stats_gpu
|
170 |
+
|
171 |
+
def run_bwd(*args, **kwargs):
|
172 |
+
graph_backward.execute(variant_pack_backward, workspace)
|
173 |
+
return dQ_gpu, dK_gpu, dV_gpu
|
174 |
+
|
175 |
+
return run_fwd, run_bwd
|
176 |
+
|
177 |
+
|
178 |
+
torch.manual_seed(0)
|
179 |
+
repeats = 100
|
180 |
+
dropout_p = 0.0
|
181 |
+
causal = False
|
182 |
+
dtype = torch.float16
|
183 |
+
device = 'cuda'
|
184 |
+
verbose = False
|
185 |
+
batch_size = 2
|
186 |
+
# seqlen = 2048
|
187 |
+
seqlen = 8192
|
188 |
+
# seqlen = 4096
|
189 |
+
# seqlen = 2047
|
190 |
+
dim = 2048
|
191 |
+
# headdim = 128
|
192 |
+
# headdim = 64
|
193 |
+
headdim = 256
|
194 |
+
|
195 |
+
for mode in ['fwd', 'bwd']:
|
196 |
+
# for mode in ['bwd']:
|
197 |
+
for headdim in [64, 128, 256]:
|
198 |
+
# for headdim in [128]:
|
199 |
+
for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
|
200 |
+
# for seqlen in [8192]:
|
201 |
+
nheads = dim // headdim
|
202 |
+
# nheads = 24
|
203 |
+
# headdim = 64
|
204 |
+
# batch_size = 64
|
205 |
+
# seqlen = 512
|
206 |
+
# nheads = 8
|
207 |
+
# headdim = 128
|
208 |
+
# nheads = 16
|
209 |
+
# headdim = 128
|
210 |
+
nheads_kv = nheads
|
211 |
+
# nheads_kv = 1
|
212 |
+
|
213 |
+
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
|
214 |
+
requires_grad=True)
|
215 |
+
q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
|
216 |
+
k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
|
217 |
+
v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
|
218 |
+
q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
|
219 |
+
k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
|
220 |
+
v_t = v.transpose(1, 2).contiguous().detach().requires_grad_()
|
221 |
+
grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
|
222 |
+
grad_t = grad.transpose(1, 2).contiguous()
|
223 |
+
o_t = torch.empty_like(q.transpose(1, 2))
|
224 |
+
stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)
|
225 |
+
|
226 |
+
bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
|
227 |
+
|
228 |
+
for causal in [False, True]:
|
229 |
+
# for causal in [True]:
|
230 |
+
print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
|
231 |
+
# For var-seq-len
|
232 |
+
lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
|
233 |
+
seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
|
234 |
+
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
|
235 |
+
if headdim <= 128 and cudnn is not None:
|
236 |
+
cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
|
237 |
+
cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
|
238 |
+
f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
|
239 |
+
ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
|
240 |
+
_, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
|
241 |
+
if mode == 'bwd':
|
242 |
+
ref_dv, v.grad = v.grad.clone(), None
|
243 |
+
ref_dk, k.grad = k.grad.clone(), None
|
244 |
+
ref_dq, q.grad = q.grad.clone(), None
|
245 |
+
# pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
|
246 |
+
if headdim <= 128:
|
247 |
+
if triton_attention is not None and nheads_kv == nheads:
|
248 |
+
if mode == 'fwd':
|
249 |
+
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
250 |
+
_, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
|
251 |
+
# TODO: fix Triton numeric errors.
|
252 |
+
# if mode == 'bwd':
|
253 |
+
# dv, v_t.grad = v_t.grad.clone(), None
|
254 |
+
# dk, k_t.grad = k_t.grad.clone(), None
|
255 |
+
# dq, q_t.grad = q_t.grad.clone(), None
|
256 |
+
# torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
|
257 |
+
# torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
|
258 |
+
# torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
|
259 |
+
if cudnn is not None:
|
260 |
+
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
261 |
+
if mode == 'fwd':
|
262 |
+
_, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
263 |
+
_, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
|
264 |
+
cudnn_sdpa_fwd()
|
265 |
+
torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
|
266 |
+
cudnn_sdpa_fwd_varlen()
|
267 |
+
torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
|
268 |
+
else:
|
269 |
+
cudnn_sdpa_fwd()
|
270 |
+
_, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
271 |
+
_, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
|
272 |
+
dq, dk, dv = cudnn_sdpa_bwd()
|
273 |
+
torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
|
274 |
+
torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
|
275 |
+
torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
|
276 |
+
dq, dk, dv = cudnn_sdpa_bwd_varlen()
|
277 |
+
torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
|
278 |
+
torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
|
279 |
+
torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
|
280 |
+
# pytorch_profiler(cudnn_sdpa, backward=False)
|
281 |
+
|
282 |
+
if headdim <= 128 or mode == 'fwd':
|
283 |
+
time.sleep(1)
|
284 |
+
_, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
|
285 |
+
q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
|
286 |
+
k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
|
287 |
+
v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
|
288 |
+
time.sleep(1)
|
289 |
+
if mode == 'bwd':
|
290 |
+
dv, v.grad = v.grad.clone(), None
|
291 |
+
dk, k.grad = k.grad.clone(), None
|
292 |
+
dq, q.grad = q.grad.clone(), None
|
293 |
+
torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
|
294 |
+
torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
|
295 |
+
torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
|
296 |
+
|
297 |
+
bench_var_fn = bench_fn
|
298 |
+
if mode == 'bwd':
|
299 |
+
grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
|
300 |
+
bench_var_fn = partial(benchmark_backward, grad=grad_var)
|
301 |
+
_, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
|
302 |
+
|
303 |
+
# pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
|
304 |
+
print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
|
305 |
+
if headdim <= 128:
|
306 |
+
if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
|
307 |
+
print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
|
308 |
+
if cudnn is not None:
|
309 |
+
print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
|
310 |
+
print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
|
311 |
+
if headdim <= 128 or mode == 'fwd':
|
312 |
+
print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
|
313 |
+
print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
|
314 |
+
|
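For reference, a worked instance of the FLOP model used by flops() above (a sketch, not part of the diff; the 0.5x causal and 2.5x backward factors follow the script's own convention).

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 8192, 8192, 16, 128

# Non-causal forward: one QK^T matmul and one PV matmul, each 2 * s_q * s_k * d
# multiply-adds per head, i.e. 4 * b * s_q * s_k * h * d FLOPs in total.
fwd = 4 * batch * seqlen_q * seqlen_k * nheads * headdim
fwd_causal = fwd // 2   # causal masking touches roughly half of the score matrix
bwd = 2.5 * fwd         # backward counted as 2.5x forward, as in flops()

print(f"fwd {fwd / 1e12:.2f} TFLOPs, causal fwd {fwd_causal / 1e12:.2f} TFLOPs, bwd {bwd / 1e12:.2f} TFLOPs")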
benchmark_causal.py
ADDED
@@ -0,0 +1,225 @@
1 |
+
from functools import partial
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
|
9 |
+
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
|
10 |
+
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
|
11 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
12 |
+
# # from flash_attn.triton.fused_attention import attention as attention
|
13 |
+
# from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
|
14 |
+
# from flash_attn.flash_attn_triton_og import attention as attention_og
|
15 |
+
|
16 |
+
# from triton.ops.flash_attention import attention as attention_triton
|
17 |
+
|
18 |
+
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
|
19 |
+
|
20 |
+
try:
|
21 |
+
from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
|
22 |
+
except ImportError:
|
23 |
+
scaled_upper_triang_masked_softmax = None
|
24 |
+
|
25 |
+
|
26 |
+
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
|
27 |
+
"""
|
28 |
+
Arguments:
|
29 |
+
qkv: (batch_size, seqlen, 3, nheads, head_dim)
|
30 |
+
dropout_p: float
|
31 |
+
Output:
|
32 |
+
output: (batch_size, seqlen, nheads, head_dim)
|
33 |
+
"""
|
34 |
+
batch_size, seqlen, _, nheads, d = qkv.shape
|
35 |
+
q, k, v = qkv.unbind(dim=2)
|
36 |
+
q = rearrange(q, 'b t h d -> (b h) t d')
|
37 |
+
k = rearrange(k, 'b s h d -> (b h) d s')
|
38 |
+
softmax_scale = 1.0 / math.sqrt(d)
|
39 |
+
# Preallocate attn_weights for `baddbmm`
|
40 |
+
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
|
41 |
+
scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
|
42 |
+
'(b h) t s -> b h t s', h=nheads)
|
43 |
+
if causal:
|
44 |
+
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
45 |
+
# So we have to construct the mask in float
|
46 |
+
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
47 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
48 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
49 |
+
attention = torch.softmax(scores, dim=-1)
|
50 |
+
attention_drop = F.dropout(attention, dropout_p)
|
51 |
+
output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
|
52 |
+
return output.to(dtype=qkv.dtype)
|
53 |
+
|
54 |
+
|
55 |
+
def attention_megatron(qkv):
|
56 |
+
"""
|
57 |
+
Arguments:
|
58 |
+
qkv: (batch_size, seqlen, 3, nheads, head_dim)
|
59 |
+
Output:
|
60 |
+
output: (batch_size, seqlen, nheads, head_dim)
|
61 |
+
"""
|
62 |
+
batch_size, seqlen, _, nheads, d = qkv.shape
|
63 |
+
q, k, v = qkv.unbind(dim=2)
|
64 |
+
q = rearrange(q, 'b t h d -> (b h) t d')
|
65 |
+
k = rearrange(k, 'b s h d -> (b h) d s')
|
66 |
+
softmax_scale = 1.0 / math.sqrt(d)
|
67 |
+
# Preallocate attn_weights for `baddbmm`
|
68 |
+
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
|
69 |
+
scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
|
70 |
+
'(b h) t s -> b h t s', h=nheads)
|
71 |
+
attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)
|
72 |
+
output = torch.einsum('bhts,bshd->bthd', attention, v)
|
73 |
+
return output.to(dtype=qkv.dtype)
|
74 |
+
|
75 |
+
|
76 |
+
torch.manual_seed(0)
|
77 |
+
repeats = 30
|
78 |
+
batch_size = 8
|
79 |
+
seqlen = 2048
|
80 |
+
nheads = 12
|
81 |
+
headdim = 128
|
82 |
+
# nheads = 24
|
83 |
+
# headdim = 64
|
84 |
+
# batch_size = 64
|
85 |
+
# seqlen = 512
|
86 |
+
# nheads = 8
|
87 |
+
# headdim = 128
|
88 |
+
dropout_p = 0.0
|
89 |
+
causal = True
|
90 |
+
dtype = torch.float16
|
91 |
+
device = 'cuda'
|
92 |
+
|
93 |
+
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
|
94 |
+
requires_grad=True)
|
95 |
+
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
96 |
+
device=qkv.device)
|
97 |
+
|
98 |
+
qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
|
99 |
+
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
|
100 |
+
# cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
|
101 |
+
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
|
102 |
+
# cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
|
103 |
+
benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
|
104 |
+
pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)
|
105 |
+
|
106 |
+
# for dropout_p in [0.1, 0.0]:
|
107 |
+
# for causal in [False, True]:
|
108 |
+
# print(f"### {dropout_p = }, {causal = } ###")
|
109 |
+
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
|
110 |
+
|
111 |
+
|
112 |
+
# nheads_k = 2
|
113 |
+
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
|
114 |
+
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
|
115 |
+
# requires_grad=True)
|
116 |
+
# if fav2_kvpacked_func is not None:
|
117 |
+
# benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
|
118 |
+
# pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True)
|
119 |
+
|
120 |
+
# dropout_p = 0.0
|
121 |
+
# causal = False
|
122 |
+
# benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
|
123 |
+
# repeats=repeats, desc='PyTorch Attention')
|
124 |
+
|
125 |
+
# benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
|
126 |
+
# pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
|
127 |
+
|
128 |
+
# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
|
129 |
+
# requires_grad=True) for _ in range(3)]
|
130 |
+
# benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
|
131 |
+
# # pytorch_profiler(attention, q, k, v, 1.0, backward=True)
|
132 |
+
|
133 |
+
# if scaled_upper_triang_masked_softmax is not None:
|
134 |
+
# benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
|
135 |
+
|
136 |
+
# from src.ops.fftconv import fftconv_func
|
137 |
+
|
138 |
+
# dim = nheads * headdim
|
139 |
+
# u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
|
140 |
+
# k = torch.randn(dim, seqlen, device=device, requires_grad=True)
|
141 |
+
# D = torch.randn(dim, device=device, requires_grad=True)
|
142 |
+
# benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv')
|
143 |
+
# pytorch_profiler(fftconv_func, u, k, D, backward=True)
|
144 |
+
# pytorch_profiler(torch.fft.rfft, u.float())
|
145 |
+
|
146 |
+
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
|
147 |
+
ideal_a100_time = flops / 312 / 1e9
|
148 |
+
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
|
149 |
+
exit(0)
|
150 |
+
|
151 |
+
|
152 |
+
def time_fwd_bwd(func, *args, **kwargs):
|
153 |
+
time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
|
154 |
+
return time_f[1].mean, time_b[1].mean
|
155 |
+
|
156 |
+
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
|
157 |
+
causal_vals = [False, True]
|
158 |
+
headdim_vals = [64, 128]
|
159 |
+
dim = 2048
|
160 |
+
dropout_p = 0.0
|
161 |
+
|
162 |
+
time_f = {}
|
163 |
+
time_b = {}
|
164 |
+
for causal in causal_vals:
|
165 |
+
for headdim in headdim_vals:
|
166 |
+
for batch_size, seqlen in bs_seqlen_vals:
|
167 |
+
nheads = dim // headdim
|
168 |
+
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
|
169 |
+
requires_grad=True)
|
170 |
+
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
171 |
+
device=qkv.device)
|
172 |
+
qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
|
173 |
+
f, b = time_fwd_bwd(
|
174 |
+
flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,
|
175 |
+
causal=causal, repeats=repeats, verbose=False
|
176 |
+
)
|
177 |
+
time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f
|
178 |
+
time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b
|
179 |
+
|
180 |
+
qkv = qkv.detach().requires_grad_(True)
|
181 |
+
f, b = time_fwd_bwd(
|
182 |
+
flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
|
183 |
+
)
|
184 |
+
time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f
|
185 |
+
time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b
|
186 |
+
|
187 |
+
# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
|
188 |
+
# requires_grad=True) for _ in range(3)]
|
189 |
+
# # Try both values of sequence_parallel and pick the faster one
|
190 |
+
# f, b = time_fwd_bwd(
|
191 |
+
# attention_triton, q, k, v, causal, headdim**(-0.5),
|
192 |
+
# False, repeats=repeats, verbose=False
|
193 |
+
# )
|
194 |
+
# _, b0 = time_fwd_bwd(
|
195 |
+
# attention_triton, q, k, v, causal, headdim**(-0.5),
|
196 |
+
# True, repeats=repeats, verbose=False
|
197 |
+
# )
|
198 |
+
# time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f
|
199 |
+
# time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0)
|
200 |
+
|
201 |
+
if seqlen <= 8 * 1024:
|
202 |
+
qkv = qkv.detach().requires_grad_(True)
|
203 |
+
f, b = time_fwd_bwd(
|
204 |
+
attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
|
205 |
+
)
|
206 |
+
else:
|
207 |
+
f, b = float('nan'), float('nan')
|
208 |
+
time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f
|
209 |
+
time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b
|
210 |
+
|
211 |
+
# q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
|
212 |
+
# requires_grad=True) for _ in range(3)]
|
213 |
+
# import xformers.ops as xops
|
214 |
+
# f, b = time_fwd_bwd(
|
215 |
+
# xops.memory_efficient_attention, q, k, v,
|
216 |
+
# attn_bias=xops.LowerTriangularMask() if causal else None,
|
217 |
+
# op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
|
218 |
+
# )
|
219 |
+
# time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f
|
220 |
+
# time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b
|
221 |
+
|
222 |
+
|
223 |
+
import pickle
|
224 |
+
with open('flash2_attn_time_h100.plk', 'wb') as fp:
|
225 |
+
pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
|
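A short sketch (not part of the diff) of reading the pickled timings back; it keeps the script's own key layout of ((causal, headdim, batch_size, seqlen), method) mapping to a mean time in seconds.

import pickle

with open('flash2_attn_time_h100.plk', 'rb') as fp:
    time_f, time_b = pickle.load(fp)

for (config, method), t_fwd in sorted(time_f.items()):
    t_bwd = time_b[config, method]
    print(f"{config} {method}: fwd {t_fwd * 1e3:.3f} ms, bwd {t_bwd * 1e3:.3f} ms")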
benchmark_flash_attention.py
ADDED
@@ -0,0 +1,180 @@
1 |
+
# Install the newest triton version with
|
2 |
+
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
|
3 |
+
import pickle
|
4 |
+
import math
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
|
12 |
+
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
|
13 |
+
|
14 |
+
from flash_attn import flash_attn_qkvpacked_func
|
15 |
+
|
16 |
+
try:
|
17 |
+
from triton.ops.flash_attention import attention as attention_triton
|
18 |
+
except ImportError:
|
19 |
+
attention_triton = None
|
20 |
+
|
21 |
+
try:
|
22 |
+
import xformers.ops as xops
|
23 |
+
except ImportError:
|
24 |
+
xops = None
|
25 |
+
|
26 |
+
|
27 |
+
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
|
28 |
+
assert mode in ["fwd", "bwd", "fwd_bwd"]
|
29 |
+
f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
|
30 |
+
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
|
31 |
+
|
32 |
+
def efficiency(flop, time):
|
33 |
+
return (flop / time / 10**12) if not math.isnan(time) else 0.0
|
34 |
+
|
35 |
+
|
36 |
+
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
|
37 |
+
"""
|
38 |
+
Arguments:
|
39 |
+
qkv: (batch_size, seqlen, 3, nheads, head_dim)
|
40 |
+
dropout_p: float
|
41 |
+
Output:
|
42 |
+
output: (batch_size, seqlen, nheads, head_dim)
|
43 |
+
"""
|
44 |
+
batch_size, seqlen, _, nheads, d = qkv.shape
|
45 |
+
q, k, v = qkv.unbind(dim=2)
|
46 |
+
q = rearrange(q, 'b t h d -> (b h) t d')
|
47 |
+
k = rearrange(k, 'b s h d -> (b h) d s')
|
48 |
+
softmax_scale = 1.0 / math.sqrt(d)
|
49 |
+
# Preallocate attn_weights for `baddbmm`
|
50 |
+
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
|
51 |
+
scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
|
52 |
+
'(b h) t s -> b h t s', h=nheads)
|
53 |
+
if causal:
|
54 |
+
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
55 |
+
# So we have to construct the mask in float
|
56 |
+
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
57 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
58 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
59 |
+
attention = torch.softmax(scores, dim=-1)
|
60 |
+
attention_drop = F.dropout(attention, dropout_p)
|
61 |
+
output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
|
62 |
+
return output.to(dtype=qkv.dtype)
|
63 |
+
|
64 |
+
|
65 |
+
def time_fwd_bwd(func, *args, **kwargs):
|
66 |
+
time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
|
67 |
+
return time_f[1].mean, time_b[1].mean
|
68 |
+
|
69 |
+
|
70 |
+
repeats = 30
|
71 |
+
device = 'cuda'
|
72 |
+
dtype = torch.float16
|
73 |
+
|
74 |
+
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
|
75 |
+
causal_vals = [False, True]
|
76 |
+
headdim_vals = [64, 128]
|
77 |
+
dim = 2048
|
78 |
+
dropout_p = 0.0
|
79 |
+
|
80 |
+
methods = (["Flash2", "Pytorch"]
|
81 |
+
+ (["Triton"] if attention_triton is not None else [])
|
82 |
+
+ (["xformers.c"] if xops is not None else [])
|
83 |
+
+ (["xformers.f"] if xops is not None else []))
|
84 |
+
|
85 |
+
time_f = {}
|
86 |
+
time_b = {}
|
87 |
+
time_f_b = {}
|
88 |
+
speed_f = {}
|
89 |
+
speed_b = {}
|
90 |
+
speed_f_b = {}
|
91 |
+
for causal in causal_vals:
|
92 |
+
for headdim in headdim_vals:
|
93 |
+
for batch_size, seqlen in bs_seqlen_vals:
|
94 |
+
config = (causal, headdim, batch_size, seqlen)
|
95 |
+
nheads = dim // headdim
|
96 |
+
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
|
97 |
+
requires_grad=True)
|
98 |
+
f, b = time_fwd_bwd(
|
99 |
+
flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
|
100 |
+
)
|
101 |
+
time_f[config, "Flash2"] = f
|
102 |
+
time_b[config, "Flash2"] = b
|
103 |
+
|
104 |
+
try:
|
105 |
+
qkv = qkv.detach().requires_grad_(True)
|
106 |
+
f, b = time_fwd_bwd(
|
107 |
+
attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
|
108 |
+
)
|
109 |
+
except: # Skip if OOM
|
110 |
+
f, b = float('nan'), float('nan')
|
111 |
+
time_f[config, "Pytorch"] = f
|
112 |
+
time_b[config, "Pytorch"] = b
|
113 |
+
|
114 |
+
if attention_triton is not None:
|
115 |
+
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
|
116 |
+
requires_grad=True) for _ in range(3)]
|
117 |
+
# Try both values of sequence_parallel and pick the faster one
|
118 |
+
try:
|
119 |
+
f, b = time_fwd_bwd(
|
120 |
+
attention_triton, q, k, v, causal, headdim**(-0.5),
|
121 |
+
False, repeats=repeats, verbose=False
|
122 |
+
)
|
123 |
+
except:
|
124 |
+
f, b = float('nan'), float('inf')
|
125 |
+
try:
|
126 |
+
_, b0 = time_fwd_bwd(
|
127 |
+
attention_triton, q, k, v, causal, headdim**(-0.5),
|
128 |
+
True, repeats=repeats, verbose=False
|
129 |
+
)
|
130 |
+
except:
|
131 |
+
b0 = float('inf')
|
132 |
+
time_f[config, "Triton"] = f
|
133 |
+
time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')
|
134 |
+
|
135 |
+
if xops is not None:
|
136 |
+
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
|
137 |
+
requires_grad=True) for _ in range(3)]
|
138 |
+
f, b = time_fwd_bwd(
|
139 |
+
xops.memory_efficient_attention, q, k, v,
|
140 |
+
attn_bias=xops.LowerTriangularMask() if causal else None,
|
141 |
+
op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
|
142 |
+
)
|
143 |
+
time_f[config, "xformers.c"] = f
|
144 |
+
time_b[config, "xformers.c"] = b
|
145 |
+
|
146 |
+
if xops is not None:
|
147 |
+
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
|
148 |
+
requires_grad=True) for _ in range(3)]
|
149 |
+
f, b = time_fwd_bwd(
|
150 |
+
xops.memory_efficient_attention, q, k, v,
|
151 |
+
attn_bias=xops.LowerTriangularMask() if causal else None,
|
152 |
+
op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
|
153 |
+
)
|
154 |
+
time_f[config, "xformers.f"] = f
|
155 |
+
time_b[config, "xformers.f"] = b
|
156 |
+
|
157 |
+
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
|
158 |
+
for method in methods:
|
159 |
+
time_f_b[config, method] = time_f[config, method] + time_b[config, method]
|
160 |
+
speed_f[config, method] = efficiency(
|
161 |
+
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
|
162 |
+
time_f[config, method]
|
163 |
+
)
|
164 |
+
speed_b[config, method] = efficiency(
|
165 |
+
flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
|
166 |
+
time_b[config, method]
|
167 |
+
)
|
168 |
+
speed_f_b[config, method] = efficiency(
|
169 |
+
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
|
170 |
+
time_f_b[config, method]
|
171 |
+
)
|
172 |
+
print(
|
173 |
+
f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
|
174 |
+
f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
|
175 |
+
f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
|
176 |
+
)
|
177 |
+
|
178 |
+
|
179 |
+
# with open('flash2_attn_time.plk', 'wb') as fp:
|
180 |
+
# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
|
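To make the printed numbers concrete, a sketch (not part of the diff) of how a measured mean time becomes the reported TFLOPs/s. The 1.6 ms figure is purely hypothetical, and the snippet assumes it runs in the same module as flops() and efficiency().

# Hypothetical measurement: Flash2 forward at batch=8, seqlen=2048, headdim=64,
# nheads = 2048 // 64 = 32, non-causal.
t_fwd = 1.6e-3  # seconds, as returned by time_fwd_bwd / benchmark_fwd_bwd

f = flops(8, 2048, 64, 32, causal=False, mode="fwd")  # FLOPs of one forward call
print(f"{efficiency(f, t_fwd):.2f} TFLOPs/s")         # flop / time / 1e12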
benchmark_flash_attention_fp8.py
ADDED
@@ -0,0 +1,333 @@
1 |
+
# Install the newest triton version with
|
2 |
+
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
|
3 |
+
import pickle
|
4 |
+
import math
|
5 |
+
import time
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
|
12 |
+
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
|
13 |
+
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
|
14 |
+
|
15 |
+
from flash_attn import flash_attn_qkvpacked_func
|
16 |
+
from flash_attn_interface import flash_attn_func
|
17 |
+
|
18 |
+
try:
|
19 |
+
from triton_fused_attention import attention as attention_triton
|
20 |
+
except ImportError:
|
21 |
+
attention_triton = None
|
22 |
+
|
23 |
+
try:
|
24 |
+
import xformers.ops as xops
|
25 |
+
except ImportError:
|
26 |
+
xops = None
|
27 |
+
|
28 |
+
try:
|
29 |
+
import cudnn
|
30 |
+
except ImportError:
|
31 |
+
cudnn = None
|
32 |
+
|
33 |
+
|
34 |
+
def convert_to_cudnn_type(torch_type):
|
35 |
+
if torch_type == torch.float16:
|
36 |
+
return cudnn.data_type.HALF
|
37 |
+
elif torch_type == torch.bfloat16:
|
38 |
+
return cudnn.data_type.BFLOAT16
|
39 |
+
elif torch_type == torch.float32:
|
40 |
+
return cudnn.data_type.FLOAT
|
41 |
+
elif torch_type == torch.int32:
|
42 |
+
return cudnn.data_type.INT32
|
43 |
+
elif torch_type == torch.int64:
|
44 |
+
return cudnn.data_type.INT64
|
45 |
+
elif torch_type == torch.float8_e4m3fn:
|
46 |
+
return cudnn.data_type.FP8_E4M3
|
47 |
+
elif torch_type == torch.float8_e5m2:
|
48 |
+
return cudnn.data_type.FP8_E5M2
|
49 |
+
else:
|
50 |
+
raise ValueError("Unsupported tensor data type.")
|
51 |
+
|
52 |
+
def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):
|
53 |
+
b, _, _, nheads, headdim = qkv.shape
|
54 |
+
assert cudnn is not None, 'CUDNN is not available'
|
55 |
+
o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device)
|
56 |
+
o_gpu_transposed = torch.as_strided(
|
57 |
+
o_gpu,
|
58 |
+
[b, nheads, seqlen_q, headdim],
|
59 |
+
[nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
|
60 |
+
)
|
61 |
+
stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device)
|
62 |
+
amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
|
63 |
+
amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
|
64 |
+
graph = cudnn.pygraph(
|
65 |
+
io_data_type=convert_to_cudnn_type(qkv.dtype),
|
66 |
+
intermediate_data_type=cudnn.data_type.FLOAT,
|
67 |
+
compute_data_type=cudnn.data_type.FLOAT,
|
68 |
+
)
|
69 |
+
new_q = torch.as_strided(
|
70 |
+
qkv,
|
71 |
+
[b, nheads, seqlen_q, headdim],
|
72 |
+
[seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
|
73 |
+
storage_offset=0,
|
74 |
+
)
|
75 |
+
q = graph.tensor(
|
76 |
+
name = "Q",
|
77 |
+
dim = list(new_q.shape),
|
78 |
+
stride = list(new_q.stride()),
|
79 |
+
data_type=convert_to_cudnn_type(qkv.dtype)
|
80 |
+
)
|
81 |
+
new_k = torch.as_strided(
|
82 |
+
qkv,
|
83 |
+
[b, nheads, seqlen_k, headdim],
|
84 |
+
[seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
|
85 |
+
storage_offset=nheads * headdim,
|
86 |
+
)
|
87 |
+
k = graph.tensor(
|
88 |
+
name = "K",
|
89 |
+
dim = list(new_k.shape),
|
90 |
+
stride = list(new_k.stride()),
|
91 |
+
data_type=convert_to_cudnn_type(qkv.dtype)
|
92 |
+
)
|
93 |
+
new_v = torch.as_strided(
|
94 |
+
qkv,
|
95 |
+
[b, nheads, seqlen_k, headdim],
|
96 |
+
[seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
|
97 |
+
storage_offset=nheads * headdim * 2,
|
98 |
+
)
|
99 |
+
v = graph.tensor(
|
100 |
+
name = "V",
|
101 |
+
dim = list(new_v.shape),
|
102 |
+
stride = list(new_v.stride()),
|
103 |
+
data_type=convert_to_cudnn_type(qkv.dtype)
|
104 |
+
)
|
105 |
+
|
106 |
+
def get_default_scale_tensor():
|
107 |
+
return graph.tensor(
|
108 |
+
dim = [1, 1, 1, 1],
|
109 |
+
stride = [1, 1, 1, 1],
|
110 |
+
data_type=cudnn.data_type.FLOAT
|
111 |
+
)
|
112 |
+
|
113 |
+
default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
|
114 |
+
descale_q = get_default_scale_tensor()
|
115 |
+
descale_k = get_default_scale_tensor()
|
116 |
+
descale_v = get_default_scale_tensor()
|
117 |
+
descale_s = get_default_scale_tensor()
|
118 |
+
scale_s = get_default_scale_tensor()
|
119 |
+
scale_o = get_default_scale_tensor()
|
120 |
+
|
121 |
+
o, _, amax_s, amax_o = graph.sdpa_fp8(
|
122 |
+
q=q,
|
123 |
+
k=k,
|
124 |
+
v=v,
|
125 |
+
descale_q=descale_q,
|
126 |
+
descale_k=descale_k,
|
127 |
+
descale_v=descale_v,
|
128 |
+
descale_s=descale_s,
|
129 |
+
scale_s=scale_s,
|
130 |
+
scale_o=scale_o,
|
131 |
+
is_inference=True,
|
132 |
+
attn_scale=1.0 / math.sqrt(headdim),
|
133 |
+
use_causal_mask=causal,
|
134 |
+
name="sdpa",
|
135 |
+
)
|
136 |
+
|
137 |
+
o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())
|
138 |
+
|
139 |
+
amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
|
140 |
+
amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
|
141 |
+
# stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
|
142 |
+
|
143 |
+
graph.validate()
|
144 |
+
graph.build_operation_graph()
|
145 |
+
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
|
146 |
+
graph.check_support()
|
147 |
+
graph.build_plans()
|
148 |
+
|
149 |
+
variant_pack = {
|
150 |
+
q: new_q,
|
151 |
+
k: new_k,
|
152 |
+
v: new_v,
|
153 |
+
descale_q: default_scale_gpu,
|
154 |
+
descale_k: default_scale_gpu,
|
155 |
+
descale_v: default_scale_gpu,
|
156 |
+
descale_s: default_scale_gpu,
|
157 |
+
scale_s: default_scale_gpu,
|
158 |
+
scale_o: default_scale_gpu,
|
159 |
+
o: o_gpu_transposed,
|
160 |
+
amax_s: amax_s_gpu,
|
161 |
+
amax_o: amax_o_gpu,
|
162 |
+
}
|
163 |
+
|
164 |
+
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
|
165 |
+
|
166 |
+
def run(*args, **kwargs):
|
167 |
+
graph.execute(variant_pack, workspace)
|
168 |
+
return o_gpu, amax_o_gpu
|
169 |
+
|
170 |
+
return run
|
171 |
+
|
172 |
+
|
173 |
+
def attention_pytorch(qkv, dropout_p=0.0, causal=True):
|
174 |
+
"""
|
175 |
+
Arguments:
|
176 |
+
qkv: (batch_size, seqlen, 3, nheads, head_dim)
|
177 |
+
dropout_p: float
|
178 |
+
Output:
|
179 |
+
output: (batch_size, seqlen, nheads, head_dim)
|
180 |
+
"""
|
181 |
+
batch_size, seqlen, _, nheads, d = qkv.shape
|
182 |
+
q, k, v = qkv.unbind(dim=2)
|
183 |
+
q = rearrange(q, 'b t h d -> (b h) t d')
|
184 |
+
k = rearrange(k, 'b s h d -> (b h) d s')
|
185 |
+
softmax_scale = 1.0 / math.sqrt(d)
|
186 |
+
# Preallocate attn_weights for `baddbmm`
|
187 |
+
scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
|
188 |
+
scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
|
189 |
+
'(b h) t s -> b h t s', h=nheads)
|
190 |
+
if causal:
|
191 |
+
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
192 |
+
# So we have to construct the mask in float
|
193 |
+
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
194 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
195 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
196 |
+
attention = torch.softmax(scores, dim=-1)
|
197 |
+
attention_drop = F.dropout(attention, dropout_p)
|
198 |
+
output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
|
199 |
+
return output.to(dtype=qkv.dtype)
|
200 |
+
|
201 |
+
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
|
202 |
+
assert mode in ["fwd", "bwd", "fwd_bwd"]
|
203 |
+
f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
|
204 |
+
return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
|
205 |
+
|
206 |
+
def efficiency(flop, time):
|
207 |
+
return (flop / time / 10**12) if not math.isnan(time) else 0.0
|
208 |
+
|
209 |
+
def time_fwd(func, *args, **kwargs):
|
210 |
+
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
211 |
+
time_f = benchmark_forward(func, *args, **kwargs)
|
212 |
+
return time_f[1].mean
|
213 |
+
|
214 |
+
|
215 |
+
torch.manual_seed(0)
|
216 |
+
|
217 |
+
repeats = 30
|
218 |
+
device = 'cuda'
|
219 |
+
# dtype = torch.float16
|
220 |
+
dtype = torch.float8_e4m3fn
|
221 |
+
|
222 |
+
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
|
223 |
+
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
|
224 |
+
# bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2), (4, 4224), (2, 8448), (1, 8448 * 2)]
|
225 |
+
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
|
226 |
+
causal_vals = [False, True]
|
227 |
+
headdim_vals = [128]
|
228 |
+
dim = 2048
|
229 |
+
# dim = 256
|
230 |
+
dropout_p = 0.0
|
231 |
+
|
232 |
+
methods = (["Pytorch", "Flash3", "cuDNN"]
|
233 |
+
# + (["Triton"] if attention_triton is not None else [])
|
234 |
+
# + (["xformers.c"] if xops is not None else [])
|
235 |
+
# + (["xformers.f"] if xops is not None else [])
|
236 |
+
)
|
237 |
+
|
238 |
+
time_f = {}
|
239 |
+
time_b = {}
|
240 |
+
time_f_b = {}
|
241 |
+
speed_f = {}
|
242 |
+
speed_b = {}
|
243 |
+
speed_f_b = {}
|
244 |
+
for causal in causal_vals:
|
245 |
+
for headdim in headdim_vals:
|
246 |
+
for batch_size, seqlen in bs_seqlen_vals:
|
247 |
+
torch.cuda.empty_cache()
|
248 |
+
config = (causal, headdim, batch_size, seqlen)
|
249 |
+
nheads = dim // headdim
|
250 |
+
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, requires_grad=False) for _ in range(3)]
|
251 |
+
|
252 |
+
qkv = torch.stack([q, k, v], dim=2)
|
253 |
+
qkv = qkv.to(torch.float16)
|
254 |
+
f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
|
255 |
+
time_f[config, "Pytorch"] = f
|
256 |
+
res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
|
257 |
+
|
258 |
+
if attention_triton is not None:
|
259 |
+
q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
|
260 |
+
k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
|
261 |
+
v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
|
262 |
+
scale = 1 / math.sqrt(headdim)
|
263 |
+
f = time_fwd(
|
264 |
+
attention_triton, q_transposed, k_transposed, v_transposed,
|
265 |
+
causal, scale, repeats=5, verbose=False, desc='Triton'
|
266 |
+
)
|
267 |
+
f = time_fwd(
|
268 |
+
attention_triton, q_transposed, k_transposed, v_transposed,
|
269 |
+
causal, scale, repeats=repeats, verbose=False, desc='Triton'
|
270 |
+
)
|
271 |
+
time_f[config, "Triton"] = f
|
272 |
+
res = attention_triton(
|
273 |
+
q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
|
274 |
+
causal, scale
|
275 |
+
).half().transpose(1, 2)
|
276 |
+
torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
|
277 |
+
|
278 |
+
# out = torch.empty_like(q)
|
279 |
+
q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
|
280 |
+
f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
|
281 |
+
|
282 |
+
# res = flash_attn_func(q, k, v, causal=causal)
|
283 |
+
# torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)
|
284 |
+
|
285 |
+
time_f[config, "Flash3"] = f
|
286 |
+
|
287 |
+
if cudnn is not None:
|
288 |
+
qkv_fp8 = qkv.to(dtype)
|
289 |
+
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
290 |
+
f = time_fwd(
|
291 |
+
cudnn_spda_setup(
|
292 |
+
qkv_fp8, seqlen, seqlen,
|
293 |
+
causal=causal
|
294 |
+
),
|
295 |
+
repeats=repeats, verbose=False
|
296 |
+
)
|
297 |
+
time_f[config, "cuDNN"] = f
|
298 |
+
# res, amax_o = cudnn_spda_setup(
|
299 |
+
# qkv_fp8, seqlen, seqlen,
|
300 |
+
# causal=causal
|
301 |
+
# )()
|
302 |
+
# res = res.half()
|
303 |
+
# TODO: CUDNN has numerics issues when
|
304 |
+
# num_heads=16, dim=128, seq_len=1024, batch_size=2
|
305 |
+
# or larger sizes.
|
306 |
+
# res_cpu = res.cpu().reshape(-1)
|
307 |
+
# res_baseline_cpu = res_baseline.cpu().reshape(-1)
|
308 |
+
# print(amax_o)
|
309 |
+
# print(res)
|
310 |
+
# print(res_baseline)
|
311 |
+
# for i in range(len(res_cpu)):
|
312 |
+
# item = res_cpu[i]
|
313 |
+
# item_baseline = res_baseline_cpu[i]
|
314 |
+
# if abs(item - item_baseline) > 0.5:
|
315 |
+
# print(i)
|
316 |
+
# print(item)
|
317 |
+
# print(item_baseline)
|
318 |
+
# torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)
|
319 |
+
|
320 |
+
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
|
321 |
+
for method in methods:
|
322 |
+
speed_f[config, method] = efficiency(
|
323 |
+
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
|
324 |
+
time_f[config, method]
|
325 |
+
)
|
326 |
+
#print (time_f[config,method])
|
327 |
+
print(
|
328 |
+
f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
|
329 |
+
)
|
330 |
+
|
331 |
+
|
332 |
+
# with open('flash3_attn_time.plk', 'wb') as fp:
|
333 |
+
# pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
|
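The flops() and efficiency() helpers above boil down to simple arithmetic. The sketch below works one configuration through that accounting by hand; the batch/seqlen/heads values mirror one of the benchmark configs (dim=2048, headdim=128), while the timing is a made-up assumption, so the printed numbers are illustrative only.

# Worked example of the FLOP accounting in flops()/efficiency() above.
# The timing below is a hypothetical assumption, not a measured result.
batch, seqlen, nheads, headdim = 4, 4224, 16, 128
fwd_flops = 4 * batch * seqlen**2 * nheads * headdim   # forward pass, non-causal
fwd_flops_causal = fwd_flops // 2                      # causal masking halves the work
assumed_time_s = 1.2e-3                                # hypothetical forward time in seconds
print(f"{fwd_flops / assumed_time_s / 10**12:.1f} TFLOPs/s (non-causal)")
print(f"{fwd_flops_causal / assumed_time_s / 10**12:.1f} TFLOPs/s (causal)")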
benchmark_gemm.py
ADDED
@@ -0,0 +1,43 @@
import time
import torch
import torch.utils.benchmark as benchmark

from triton.testing import do_bench


def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, '- Forward pass')
    t = benchmark.Timer(
        stmt='fn(*inputs, **kwinputs)',
        globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


torch.manual_seed(0)
repeats = 30
dtype = torch.float16
device = 'cuda'
verbose = False
m, n = 8192, 8192

tflops_matmul = {}
tflops_matmul1 = {}
for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
    a = torch.randn(m, k, device=device, dtype=dtype)
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    nFLOPS_matmul = 2 * m * n * k
    time.sleep(2)  # to reduce power throttling
    timing = benchmark_forward(torch.matmul, a, b, desc='cuBLAS', verbose=verbose, repeats=repeats)[1]
    tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12
    print(f'[torch.utils.benchmark] cuBLAS, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS')
    time.sleep(2)  # to reduce power throttling
    ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)
    tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9
    print(f'[triton.test.do_bench] cuBLAS, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')
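As a quick sanity check on the 2 * m * n * k FLOP count used above, the snippet below plugs in one illustrative problem size and an assumed runtime (not a measurement):

# Sanity check of the GEMM FLOP count used in benchmark_gemm.py above.
m, n, k = 8192, 8192, 4096
nflops = 2 * m * n * k            # 549,755,813,888 FLOPs, i.e. ~0.55 TFLOP
assumed_ms = 1.0                  # hypothetical cuBLAS time for this problem size
print(f"{nflops / (assumed_ms * 1e-3) * 1e-12:.0f} TFLOPS if the GEMM takes {assumed_ms} ms")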
bert.py
ADDED
@@ -0,0 +1,764 @@
1 |
+
# Copyright (c) 2022, Tri Dao.
|
2 |
+
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
3 |
+
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
4 |
+
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
|
5 |
+
|
6 |
+
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import re
|
10 |
+
from collections import OrderedDict
|
11 |
+
from collections.abc import Sequence
|
12 |
+
from functools import partial
|
13 |
+
from typing import Any, Mapping
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from einops import rearrange
|
19 |
+
from transformers import BertConfig, PretrainedConfig
|
20 |
+
from transformers.models.bert.modeling_bert import (
|
21 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
22 |
+
BertForPreTrainingOutput,
|
23 |
+
)
|
24 |
+
|
25 |
+
from flash_attn.bert_padding import (
|
26 |
+
index_first_axis,
|
27 |
+
index_first_axis_residual,
|
28 |
+
pad_input,
|
29 |
+
unpad_input,
|
30 |
+
)
|
31 |
+
from flash_attn.modules.block import Block
|
32 |
+
from flash_attn.modules.embedding import BertEmbeddings
|
33 |
+
from flash_attn.modules.mha import MHA
|
34 |
+
from flash_attn.modules.mlp import FusedMLP, Mlp
|
35 |
+
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
36 |
+
|
37 |
+
try:
|
38 |
+
from flash_attn.ops.fused_dense import FusedDense
|
39 |
+
except ImportError:
|
40 |
+
FusedDense = None
|
41 |
+
|
42 |
+
try:
|
43 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn
|
44 |
+
except ImportError:
|
45 |
+
layer_norm_fn = None
|
46 |
+
|
47 |
+
|
48 |
+
try:
|
49 |
+
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
50 |
+
except ImportError:
|
51 |
+
CrossEntropyLoss = None
|
52 |
+
|
53 |
+
|
54 |
+
logger = logging.getLogger(__name__)
|
55 |
+
|
56 |
+
|
57 |
+
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
58 |
+
use_flash_attn = getattr(config, "use_flash_attn", False)
|
59 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
60 |
+
rotary_kwargs = {}
|
61 |
+
if config.position_embedding_type == "rotary":
|
62 |
+
rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
|
63 |
+
rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
|
64 |
+
rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
|
65 |
+
rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
|
66 |
+
mixer_cls = partial(
|
67 |
+
MHA,
|
68 |
+
num_heads=config.num_attention_heads,
|
69 |
+
cross_attn=cross_attn,
|
70 |
+
dropout=config.attention_probs_dropout_prob,
|
71 |
+
causal=False,
|
72 |
+
fused_bias_fc=fused_bias_fc,
|
73 |
+
use_flash_attn=use_flash_attn,
|
74 |
+
return_residual=return_residual,
|
75 |
+
**rotary_kwargs,
|
76 |
+
)
|
77 |
+
return mixer_cls
|
78 |
+
|
79 |
+
|
80 |
+
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
81 |
+
inner_dim = config.intermediate_size
|
82 |
+
fused_mlp = getattr(config, "fused_mlp", False)
|
83 |
+
if fused_mlp:
|
84 |
+
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
|
85 |
+
"fused_mlp only " "supports approximate gelu"
|
86 |
+
)
|
87 |
+
if not fused_mlp:
|
88 |
+
approximate = (
|
89 |
+
"tanh"
|
90 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
91 |
+
else "none"
|
92 |
+
)
|
93 |
+
mlp_cls = partial(
|
94 |
+
Mlp,
|
95 |
+
hidden_features=inner_dim,
|
96 |
+
activation=partial(F.gelu, approximate=approximate),
|
97 |
+
return_residual=return_residual,
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
if FusedMLP is None:
|
101 |
+
raise ImportError("fused_dense is not installed")
|
102 |
+
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
103 |
+
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
104 |
+
if isinstance(mlp_checkpoint_lvl, Sequence):
|
105 |
+
assert layer_idx is not None
|
106 |
+
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
107 |
+
mlp_cls = partial(
|
108 |
+
FusedMLP,
|
109 |
+
hidden_features=inner_dim,
|
110 |
+
checkpoint_lvl=mlp_checkpoint_lvl,
|
111 |
+
return_residual=return_residual,
|
112 |
+
)
|
113 |
+
return mlp_cls
|
114 |
+
|
115 |
+
|
116 |
+
def create_block(config, layer_idx=None):
|
117 |
+
last_layer_subset = getattr(config, "last_layer_subset", False)
|
118 |
+
cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
|
119 |
+
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
|
120 |
+
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
|
121 |
+
# one layer) so we just choose not to return residual in this case.
|
122 |
+
return_residual = not cross_attn
|
123 |
+
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
|
124 |
+
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
|
125 |
+
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
|
126 |
+
block = Block(
|
127 |
+
config.hidden_size,
|
128 |
+
mixer_cls,
|
129 |
+
mlp_cls,
|
130 |
+
norm_cls=norm_cls,
|
131 |
+
prenorm=False,
|
132 |
+
resid_dropout1=config.hidden_dropout_prob,
|
133 |
+
resid_dropout2=config.hidden_dropout_prob,
|
134 |
+
fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
|
135 |
+
return_residual=return_residual,
|
136 |
+
)
|
137 |
+
return block
|
138 |
+
|
139 |
+
|
140 |
+
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
|
141 |
+
def _init_weights(module, initializer_range=0.02):
|
142 |
+
if isinstance(module, nn.Linear):
|
143 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
144 |
+
if module.bias is not None:
|
145 |
+
nn.init.zeros_(module.bias)
|
146 |
+
elif isinstance(module, nn.Embedding):
|
147 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
148 |
+
if module.padding_idx is not None:
|
149 |
+
nn.init.zeros_(module.weight[module.padding_idx])
|
150 |
+
|
151 |
+
|
152 |
+
class BertEncoder(nn.Module):
|
153 |
+
def __init__(self, config: BertConfig):
|
154 |
+
super().__init__()
|
155 |
+
self.use_flash_attn = getattr(config, "use_flash_attn", False)
|
156 |
+
self.layers = nn.ModuleList(
|
157 |
+
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
158 |
+
)
|
159 |
+
|
160 |
+
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
161 |
+
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
162 |
+
This means that we only compute the last layer output for these tokens.
|
163 |
+
subset_mask: (batch, seqlen), dtype=torch.bool
|
164 |
+
"""
|
165 |
+
if key_padding_mask is None or not self.use_flash_attn:
|
166 |
+
mixer_kwargs = (
|
167 |
+
{"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
|
168 |
+
)
|
169 |
+
for layer in self.layers:
|
170 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
171 |
+
if subset_mask is not None:
|
172 |
+
hidden_states = hidden_states[subset_mask]
|
173 |
+
else:
|
174 |
+
batch, seqlen = hidden_states.shape[:2]
|
175 |
+
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
|
176 |
+
hidden_states, key_padding_mask
|
177 |
+
)
|
178 |
+
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
179 |
+
if subset_mask is None:
|
180 |
+
for layer in self.layers:
|
181 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
182 |
+
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
183 |
+
else:
|
184 |
+
for layer in self.layers[:-1]:
|
185 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
186 |
+
if key_padding_mask is not None:
|
187 |
+
subset_idx = torch.nonzero(
|
188 |
+
subset_mask[key_padding_mask], as_tuple=False
|
189 |
+
).flatten()
|
190 |
+
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
|
191 |
+
subset_cu_seqlens = F.pad(
|
192 |
+
torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
|
196 |
+
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
|
197 |
+
subset_cu_seqlens = F.pad(
|
198 |
+
torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
|
199 |
+
)
|
200 |
+
hidden_states_subset, hidden_states = index_first_axis_residual(
|
201 |
+
hidden_states, subset_idx
|
202 |
+
)
|
203 |
+
# It's ok to set max_seqlen_q to be much larger
|
204 |
+
mixer_kwargs = {
|
205 |
+
"x_kv": hidden_states,
|
206 |
+
"cu_seqlens": subset_cu_seqlens,
|
207 |
+
"max_seqlen": max_seqlen_in_batch,
|
208 |
+
"cu_seqlens_k": cu_seqlens,
|
209 |
+
"max_seqlen_k": max_seqlen_in_batch,
|
210 |
+
}
|
211 |
+
hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
|
212 |
+
return hidden_states
|
213 |
+
|
214 |
+
|
215 |
+
class BertPooler(nn.Module):
|
216 |
+
def __init__(self, config):
|
217 |
+
super().__init__()
|
218 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
219 |
+
if fused_bias_fc and FusedDense is None:
|
220 |
+
raise ImportError("fused_dense is not installed")
|
221 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
222 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
223 |
+
self.activation = nn.Tanh()
|
224 |
+
|
225 |
+
def forward(self, hidden_states, pool=True):
|
226 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
227 |
+
# to the first token.
|
228 |
+
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
229 |
+
pooled_output = self.dense(first_token_tensor)
|
230 |
+
pooled_output = self.activation(pooled_output)
|
231 |
+
return pooled_output
|
232 |
+
|
233 |
+
|
234 |
+
class BertPredictionHeadTransform(nn.Module):
|
235 |
+
def __init__(self, config):
|
236 |
+
super().__init__()
|
237 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
238 |
+
if fused_bias_fc and FusedDense is None:
|
239 |
+
raise ImportError("fused_dense is not installed")
|
240 |
+
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
241 |
+
if self.fused_dropout_add_ln and layer_norm_fn is None:
|
242 |
+
raise ImportError("Triton is not installed")
|
243 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
244 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
245 |
+
approximate = (
|
246 |
+
"tanh"
|
247 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
248 |
+
else "none"
|
249 |
+
)
|
250 |
+
self.transform_act_fn = nn.GELU(approximate=approximate)
|
251 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
252 |
+
|
253 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
254 |
+
hidden_states = self.dense(hidden_states)
|
255 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
256 |
+
if not self.fused_dropout_add_ln:
|
257 |
+
hidden_states = self.layer_norm(hidden_states)
|
258 |
+
else:
|
259 |
+
hidden_states = layer_norm_fn(
|
260 |
+
hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
|
261 |
+
)
|
262 |
+
return hidden_states
|
263 |
+
|
264 |
+
|
265 |
+
class BertLMPredictionHead(nn.Module):
|
266 |
+
def __init__(self, config):
|
267 |
+
super().__init__()
|
268 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
269 |
+
if fused_bias_fc and FusedDense is None:
|
270 |
+
raise ImportError("fused_dense is not installed")
|
271 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
272 |
+
|
273 |
+
self.transform = BertPredictionHeadTransform(config)
|
274 |
+
|
275 |
+
# The output weights are the same as the input embeddings, but there is
|
276 |
+
# an output-only bias for each token.
|
277 |
+
self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
|
278 |
+
|
279 |
+
def forward(self, hidden_states):
|
280 |
+
hidden_states = self.transform(hidden_states)
|
281 |
+
hidden_states = self.decoder(hidden_states)
|
282 |
+
return hidden_states
|
283 |
+
|
284 |
+
|
285 |
+
class BertPreTrainingHeads(nn.Module):
|
286 |
+
def __init__(self, config):
|
287 |
+
super().__init__()
|
288 |
+
self.predictions = BertLMPredictionHead(config)
|
289 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
290 |
+
|
291 |
+
def forward(self, sequence_output, pooled_output):
|
292 |
+
prediction_scores = self.predictions(sequence_output)
|
293 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
294 |
+
return prediction_scores, seq_relationship_score
|
295 |
+
|
296 |
+
|
297 |
+
class BertPreTrainedModel(nn.Module):
|
298 |
+
"""An abstract class to handle weights initialization and
|
299 |
+
a simple interface for downloading and loading pretrained models.
|
300 |
+
"""
|
301 |
+
|
302 |
+
def __init__(self, config, *inputs, **kwargs):
|
303 |
+
super().__init__()
|
304 |
+
if not isinstance(config, BertConfig):
|
305 |
+
raise ValueError(
|
306 |
+
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
|
307 |
+
"To create a model from a Google pretrained model use "
|
308 |
+
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
309 |
+
self.__class__.__name__, self.__class__.__name__
|
310 |
+
)
|
311 |
+
)
|
312 |
+
self.config = config
|
313 |
+
|
314 |
+
@classmethod
|
315 |
+
def from_pretrained(cls, model_name, config, *inputs, **kwargs):
|
316 |
+
"""
|
317 |
+
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
318 |
+
Download and cache the pre-trained model file if needed.
|
319 |
+
|
320 |
+
Params:
|
321 |
+
pretrained_model_name_or_path: either:
|
322 |
+
- a path or url to a pretrained model archive containing:
|
323 |
+
. `bert_config.json` a configuration file for the model
|
324 |
+
. `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
|
325 |
+
- a path or url to a pretrained model archive containing:
|
326 |
+
. `bert_config.json` a configuration file for the model
|
327 |
+
. `model.chkpt` a TensorFlow checkpoint
|
328 |
+
*inputs, **kwargs: additional input for the specific Bert class
|
329 |
+
(ex: num_labels for BertForSequenceClassification)
|
330 |
+
"""
|
331 |
+
# Instantiate model.
|
332 |
+
model = cls(config, *inputs, **kwargs)
|
333 |
+
load_return = model.load_state_dict(
|
334 |
+
remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
|
335 |
+
)
|
336 |
+
logger.info(load_return)
|
337 |
+
return model
|
338 |
+
|
339 |
+
|
340 |
+
class BertModel(BertPreTrainedModel):
|
341 |
+
def __init__(self, config: BertConfig, add_pooling_layer=True):
|
342 |
+
super().__init__(config)
|
343 |
+
self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
344 |
+
if config.vocab_size % self.pad_vocab_size_multiple != 0:
|
345 |
+
config.vocab_size += self.pad_vocab_size_multiple - (
|
346 |
+
config.vocab_size % self.pad_vocab_size_multiple
|
347 |
+
)
|
348 |
+
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
349 |
+
if self.fused_dropout_add_ln and layer_norm_fn is None:
|
350 |
+
raise ImportError("Triton is not installed")
|
351 |
+
assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
352 |
+
|
353 |
+
self.embeddings = BertEmbeddings(
|
354 |
+
config.hidden_size,
|
355 |
+
config.vocab_size,
|
356 |
+
config.max_position_embeddings,
|
357 |
+
config.type_vocab_size,
|
358 |
+
padding_idx=config.pad_token_id,
|
359 |
+
)
|
360 |
+
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
361 |
+
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
362 |
+
self.encoder = BertEncoder(config)
|
363 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
364 |
+
|
365 |
+
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
366 |
+
|
367 |
+
def forward(
|
368 |
+
self,
|
369 |
+
input_ids,
|
370 |
+
position_ids=None,
|
371 |
+
token_type_ids=None,
|
372 |
+
attention_mask=None,
|
373 |
+
masked_tokens_mask=None,
|
374 |
+
):
|
375 |
+
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
|
376 |
+
we only want the output for the masked tokens. This means that we only compute the last
|
377 |
+
layer output for these tokens.
|
378 |
+
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
379 |
+
"""
|
380 |
+
hidden_states = self.embeddings(
|
381 |
+
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
382 |
+
)
|
383 |
+
# TD [2022-12-18]: Don't need to force residual in fp32
|
384 |
+
# BERT puts embedding LayerNorm before embedding dropout.
|
385 |
+
if not self.fused_dropout_add_ln:
|
386 |
+
hidden_states = self.emb_ln(hidden_states)
|
387 |
+
else:
|
388 |
+
hidden_states = layer_norm_fn(
|
389 |
+
hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
|
390 |
+
)
|
391 |
+
hidden_states = self.emb_drop(hidden_states)
|
392 |
+
|
393 |
+
if masked_tokens_mask is not None:
|
394 |
+
batch_size, seqlen = input_ids.shape[:2]
|
395 |
+
# We also need the first column for the CLS token
|
396 |
+
first_col_mask = torch.zeros(
|
397 |
+
batch_size, seqlen, dtype=torch.bool, device=input_ids.device
|
398 |
+
)
|
399 |
+
first_col_mask[:, 0] = True
|
400 |
+
subset_mask = masked_tokens_mask | first_col_mask
|
401 |
+
else:
|
402 |
+
subset_mask = None
|
403 |
+
|
404 |
+
sequence_output = self.encoder(
|
405 |
+
hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
|
406 |
+
)
|
407 |
+
|
408 |
+
if masked_tokens_mask is None:
|
409 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
410 |
+
else:
|
411 |
+
# TD [2022-03-01]: the indexing here is very tricky.
|
412 |
+
if attention_mask is not None:
|
413 |
+
subset_idx = subset_mask[attention_mask]
|
414 |
+
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
|
415 |
+
sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
|
416 |
+
else:
|
417 |
+
pool_input = sequence_output[first_col_mask[subset_mask]]
|
418 |
+
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
419 |
+
pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
|
420 |
+
|
421 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
422 |
+
last_hidden_state=sequence_output,
|
423 |
+
pooler_output=pooled_output,
|
424 |
+
)
|
425 |
+
|
426 |
+
|
427 |
+
class BertForPreTraining(BertPreTrainedModel):
|
428 |
+
def __init__(self, config: BertConfig):
|
429 |
+
super().__init__(config)
|
430 |
+
# If dense_seq_output, we only need to pass the hidden states for the masked out tokens
|
431 |
+
# (around 15%) to the classifier heads.
|
432 |
+
self.dense_seq_output = getattr(config, "dense_seq_output", False)
|
433 |
+
# If last_layer_subset, we only need the compute the last layer for a subset of tokens
|
434 |
+
# (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
|
435 |
+
self.last_layer_subset = getattr(config, "last_layer_subset", False)
|
436 |
+
if self.last_layer_subset:
|
437 |
+
assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
|
438 |
+
use_xentropy = getattr(config, "use_xentropy", False)
|
439 |
+
if use_xentropy and CrossEntropyLoss is None:
|
440 |
+
raise ImportError("xentropy_cuda is not installed")
|
441 |
+
loss_cls = (
|
442 |
+
nn.CrossEntropyLoss
|
443 |
+
if not use_xentropy
|
444 |
+
else partial(CrossEntropyLoss, inplace_backward=True)
|
445 |
+
)
|
446 |
+
|
447 |
+
self.bert = BertModel(config)
|
448 |
+
self.cls = BertPreTrainingHeads(config)
|
449 |
+
self.mlm_loss = loss_cls(ignore_index=0)
|
450 |
+
self.nsp_loss = loss_cls(ignore_index=-1)
|
451 |
+
|
452 |
+
# Initialize weights and apply final processing
|
453 |
+
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
454 |
+
self.tie_weights()
|
455 |
+
|
456 |
+
def tie_weights(self):
|
457 |
+
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
|
458 |
+
|
459 |
+
def forward(
|
460 |
+
self,
|
461 |
+
input_ids,
|
462 |
+
position_ids=None,
|
463 |
+
token_type_ids=None,
|
464 |
+
attention_mask=None,
|
465 |
+
labels=None,
|
466 |
+
next_sentence_label=None,
|
467 |
+
):
|
468 |
+
"""
|
469 |
+
If labels are provided, they must be 0 for masked out tokens (as specified in the attention
|
470 |
+
mask).
|
471 |
+
Outputs:
|
472 |
+
if `labels` and `next_sentence_label` are not `None`:
|
473 |
+
Outputs the total_loss which is the sum of the masked language modeling loss and the next
|
474 |
+
sentence classification loss.
|
475 |
+
if `labels` or `next_sentence_label` is `None`:
|
476 |
+
Outputs a tuple comprising
|
477 |
+
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
|
478 |
+
- the next sentence classification logits of shape [batch_size, 2].
|
479 |
+
|
480 |
+
"""
|
481 |
+
masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
|
482 |
+
outputs = self.bert(
|
483 |
+
input_ids,
|
484 |
+
position_ids=position_ids,
|
485 |
+
token_type_ids=token_type_ids,
|
486 |
+
attention_mask=attention_mask.bool() if attention_mask is not None else None,
|
487 |
+
masked_tokens_mask=masked_tokens_mask,
|
488 |
+
)
|
489 |
+
sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
|
490 |
+
if self.dense_seq_output and labels is not None:
|
491 |
+
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
|
492 |
+
if not self.last_layer_subset:
|
493 |
+
sequence_output = index_first_axis(
|
494 |
+
rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
|
495 |
+
)
|
496 |
+
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
497 |
+
|
498 |
+
total_loss = None
|
499 |
+
if labels is not None and next_sentence_label is not None:
|
500 |
+
if (
|
501 |
+
self.dense_seq_output and labels is not None
|
502 |
+
): # prediction_scores are already flattened
|
503 |
+
masked_lm_loss = self.mlm_loss(
|
504 |
+
prediction_scores, labels.flatten()[masked_token_idx]
|
505 |
+
)
|
506 |
+
else:
|
507 |
+
masked_lm_loss = self.mlm_loss(
|
508 |
+
rearrange(prediction_scores, "... v -> (...) v"),
|
509 |
+
rearrange(labels, "... -> (...)"),
|
510 |
+
)
|
511 |
+
next_sentence_loss = self.nsp_loss(
|
512 |
+
rearrange(seq_relationship_score, "... t -> (...) t"),
|
513 |
+
rearrange(next_sentence_label, "... -> (...)"),
|
514 |
+
)
|
515 |
+
total_loss = masked_lm_loss.float() + next_sentence_loss.float()
|
516 |
+
|
517 |
+
return BertForPreTrainingOutput(
|
518 |
+
loss=total_loss,
|
519 |
+
prediction_logits=prediction_scores,
|
520 |
+
seq_relationship_logits=seq_relationship_score,
|
521 |
+
)
|
522 |
+
|
523 |
+
|
524 |
+
def remap_state_dict(state_dict, config: PretrainedConfig):
|
525 |
+
"""
|
526 |
+
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
|
527 |
+
"""
|
528 |
+
|
529 |
+
# LayerNorm
|
530 |
+
def key_mapping_ln_gamma_beta(key):
|
531 |
+
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
|
532 |
+
key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
|
533 |
+
return key
|
534 |
+
|
535 |
+
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
|
536 |
+
|
537 |
+
# Layers
|
538 |
+
def key_mapping_layers(key):
|
539 |
+
return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
|
540 |
+
|
541 |
+
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
542 |
+
|
543 |
+
# LayerNorm
|
544 |
+
def key_mapping_ln(key):
|
545 |
+
key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
|
546 |
+
key = re.sub(
|
547 |
+
r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
|
548 |
+
r"bert.encoder.layers.\1.norm1.\2",
|
549 |
+
key,
|
550 |
+
)
|
551 |
+
key = re.sub(
|
552 |
+
r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
|
553 |
+
r"bert.encoder.layers.\1.norm2.\2",
|
554 |
+
key,
|
555 |
+
)
|
556 |
+
key = re.sub(
|
557 |
+
r"^cls.predictions.transform.LayerNorm.(weight|bias)",
|
558 |
+
r"cls.predictions.transform.layer_norm.\1",
|
559 |
+
key,
|
560 |
+
)
|
561 |
+
return key
|
562 |
+
|
563 |
+
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
564 |
+
|
565 |
+
# MLP
|
566 |
+
def key_mapping_mlp(key):
|
567 |
+
key = re.sub(
|
568 |
+
r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
|
569 |
+
r"bert.encoder.layers.\1.mlp.fc1.\2",
|
570 |
+
key,
|
571 |
+
)
|
572 |
+
key = re.sub(
|
573 |
+
r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
|
574 |
+
r"bert.encoder.layers.\1.mlp.fc2.\2",
|
575 |
+
key,
|
576 |
+
)
|
577 |
+
return key
|
578 |
+
|
579 |
+
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
580 |
+
|
581 |
+
# Attention
|
582 |
+
last_layer_subset = getattr(config, "last_layer_subset", False)
|
583 |
+
for d in range(config.num_hidden_layers):
|
584 |
+
Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
|
585 |
+
Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
|
586 |
+
Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
|
587 |
+
bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
|
588 |
+
bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
|
589 |
+
bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
|
590 |
+
if not (last_layer_subset and d == config.num_hidden_layers - 1):
|
591 |
+
state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
|
592 |
+
[Wq, Wk, Wv], dim=0
|
593 |
+
)
|
594 |
+
state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
|
595 |
+
else:
|
596 |
+
state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
|
597 |
+
state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
|
598 |
+
state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
|
599 |
+
state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
|
600 |
+
|
601 |
+
def key_mapping_attn(key):
|
602 |
+
return re.sub(
|
603 |
+
r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
|
604 |
+
r"bert.encoder.layers.\1.mixer.out_proj.\2",
|
605 |
+
key,
|
606 |
+
)
|
607 |
+
|
608 |
+
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
609 |
+
|
610 |
+
def key_mapping_decoder_bias(key):
|
611 |
+
return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
|
612 |
+
|
613 |
+
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
|
614 |
+
|
615 |
+
# Word embedding
|
616 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
617 |
+
if pad_vocab_size_multiple > 1:
|
618 |
+
word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
|
619 |
+
state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
|
620 |
+
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
|
621 |
+
)
|
622 |
+
decoder_weight = state_dict["cls.predictions.decoder.weight"]
|
623 |
+
state_dict["cls.predictions.decoder.weight"] = F.pad(
|
624 |
+
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
|
625 |
+
)
|
626 |
+
# If the vocab was padded, we want to set the decoder bias for those padded indices to be
|
627 |
+
# strongly negative (i.e. the decoder shouldn't predict those indices).
|
628 |
+
# TD [2022-05-09]: I don't think it affects the MLPerf training.
|
629 |
+
decoder_bias = state_dict["cls.predictions.decoder.bias"]
|
630 |
+
state_dict["cls.predictions.decoder.bias"] = F.pad(
|
631 |
+
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
|
632 |
+
)
|
633 |
+
|
634 |
+
return state_dict
|
635 |
+
|
636 |
+
|
637 |
+
def inv_remap_state_dict(state_dict, config: PretrainedConfig):
|
638 |
+
"""
|
639 |
+
Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
|
640 |
+
|
641 |
+
This function is meant to be the inverse of remap_state_dict.
|
642 |
+
"""
|
643 |
+
# Word embedding
|
644 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
645 |
+
if pad_vocab_size_multiple > 1:
|
646 |
+
word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
|
647 |
+
decoder_weight = state_dict["cls.predictions.decoder.weight"]
|
648 |
+
decoder_bias = state_dict["cls.predictions.decoder.bias"]
|
649 |
+
# unpad embeddings
|
650 |
+
state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
|
651 |
+
: config.orig_vocab_size, :
|
652 |
+
]
|
653 |
+
state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
|
654 |
+
state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
|
655 |
+
|
656 |
+
for d in range(config.num_hidden_layers):
|
657 |
+
last_layer_subset = getattr(config, "last_layer_subset", False)
|
658 |
+
if not last_layer_subset or d != (config.num_hidden_layers - 1):
|
659 |
+
Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
|
660 |
+
Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
|
661 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
|
662 |
+
: Wqkv_weights.shape[0] // 3, :
|
663 |
+
]
|
664 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
|
665 |
+
Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
|
666 |
+
]
|
667 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
|
668 |
+
2 * Wqkv_weights.shape[0] // 3 :, :
|
669 |
+
]
|
670 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
|
671 |
+
: Wqkv_biases.shape[0] // 3
|
672 |
+
]
|
673 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
|
674 |
+
Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
|
675 |
+
]
|
676 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
|
677 |
+
2 * Wqkv_biases.shape[0] // 3 :
|
678 |
+
]
|
679 |
+
else:
|
680 |
+
Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
|
681 |
+
Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
|
682 |
+
Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
|
683 |
+
Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
|
684 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
|
685 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
|
686 |
+
: Wkv_weights.shape[0] // 2, :
|
687 |
+
]
|
688 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
|
689 |
+
Wkv_weights.shape[0] // 2 :, :
|
690 |
+
]
|
691 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
|
692 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
|
693 |
+
: Wkv_biases.shape[0] // 2
|
694 |
+
]
|
695 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
|
696 |
+
Wkv_biases.shape[0] // 2 :
|
697 |
+
]
|
698 |
+
|
699 |
+
def inv_key_mapping_ln(key):
|
700 |
+
key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
|
701 |
+
key = re.sub(
|
702 |
+
r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
|
703 |
+
r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
|
704 |
+
key,
|
705 |
+
)
|
706 |
+
key = re.sub(
|
707 |
+
r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
|
708 |
+
r"bert.encoder.layers.\1.output.LayerNorm.\2",
|
709 |
+
key,
|
710 |
+
)
|
711 |
+
key = re.sub(
|
712 |
+
r"cls.predictions.transform.layer_norm.(weight|bias)",
|
713 |
+
r"cls.predictions.transform.LayerNorm.\1",
|
714 |
+
key,
|
715 |
+
)
|
716 |
+
return key
|
717 |
+
|
718 |
+
def inv_key_mapping_ln_gamma_beta(key):
|
719 |
+
key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
|
720 |
+
key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
|
721 |
+
return key
|
722 |
+
|
723 |
+
def inv_key_mapping_layers(key):
|
724 |
+
return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
|
725 |
+
|
726 |
+
def inv_key_mapping_mlp(key):
|
727 |
+
key = re.sub(
|
728 |
+
r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
|
729 |
+
r"bert.encoder.layer.\1.intermediate.dense.\2",
|
730 |
+
key,
|
731 |
+
)
|
732 |
+
key = re.sub(
|
733 |
+
r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
|
734 |
+
r"bert.encoder.layer.\1.output.dense.\2",
|
735 |
+
key,
|
736 |
+
)
|
737 |
+
return key
|
738 |
+
|
739 |
+
def inv_key_mapping_attn(key):
|
740 |
+
return re.sub(
|
741 |
+
r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
|
742 |
+
r"bert.encoder.layer.\1.attention.output.dense.\2",
|
743 |
+
key,
|
744 |
+
)
|
745 |
+
|
746 |
+
def inv_key_mapping_decoder_bias(key):
|
747 |
+
return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
|
748 |
+
|
749 |
+
state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
|
750 |
+
state_dict = OrderedDict(
|
751 |
+
(inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
|
752 |
+
)
|
753 |
+
state_dict = OrderedDict(
|
754 |
+
(inv_key_mapping_layers(key), value) for key, value in state_dict.items()
|
755 |
+
)
|
756 |
+
state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
|
757 |
+
state_dict = OrderedDict(
|
758 |
+
(inv_key_mapping_attn(key), value) for key, value in state_dict.items()
|
759 |
+
)
|
760 |
+
state_dict = OrderedDict(
|
761 |
+
(inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
|
762 |
+
)
|
763 |
+
|
764 |
+
return state_dict
|
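A minimal usage sketch for the classes above, assuming the flash_attn package is installed and a CUDA device is available. The checkpoint name and config flags are illustrative, not a recommended configuration; from_pretrained() downloads the Huggingface checkpoint and runs remap_state_dict() on it before loading.

import torch
from transformers import BertConfig
# Module path as used in the flash_attn package; adjust if this file is vendored elsewhere.
from flash_attn.models.bert import BertForPreTraining

config = BertConfig.from_pretrained("bert-base-uncased")
config.use_flash_attn = True          # requires the flash-attn kernels and fp16/bf16 inputs
config.fused_bias_fc = False
config.fused_mlp = False
config.fused_dropout_add_ln = False

model = BertForPreTraining.from_pretrained("bert-base-uncased", config)
model = model.to(device="cuda", dtype=torch.float16)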
bert_padding.py
ADDED
@@ -0,0 +1,213 @@
1 |
+
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
|
7 |
+
|
8 |
+
class IndexFirstAxis(torch.autograd.Function):
|
9 |
+
@staticmethod
|
10 |
+
def forward(ctx, input, indices):
|
11 |
+
ctx.save_for_backward(indices)
|
12 |
+
assert input.ndim >= 2
|
13 |
+
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
14 |
+
second_dim = other_shape.numel()
|
15 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
16 |
+
# return input[indices]
|
17 |
+
return torch.gather(
|
18 |
+
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
|
19 |
+
).reshape(-1, *other_shape)
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def backward(ctx, grad_output):
|
23 |
+
(indices,) = ctx.saved_tensors
|
24 |
+
assert grad_output.ndim >= 2
|
25 |
+
other_shape = grad_output.shape[1:]
|
26 |
+
grad_output = rearrange(grad_output, "b ... -> b (...)")
|
27 |
+
grad_input = torch.zeros(
|
28 |
+
[ctx.first_axis_dim, grad_output.shape[1]],
|
29 |
+
device=grad_output.device,
|
30 |
+
dtype=grad_output.dtype,
|
31 |
+
)
|
32 |
+
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
33 |
+
# grad_input[indices] = grad_output
|
34 |
+
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
|
35 |
+
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
36 |
+
|
37 |
+
|
38 |
+
index_first_axis = IndexFirstAxis.apply
|
39 |
+
|
40 |
+
|
41 |
+
class IndexPutFirstAxis(torch.autograd.Function):
|
42 |
+
@staticmethod
|
43 |
+
def forward(ctx, values, indices, first_axis_dim):
|
44 |
+
ctx.save_for_backward(indices)
|
45 |
+
assert indices.ndim == 1
|
46 |
+
assert values.ndim >= 2
|
47 |
+
output = torch.zeros(
|
48 |
+
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
|
49 |
+
)
|
50 |
+
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
51 |
+
output[indices] = values
|
52 |
+
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
|
53 |
+
return output
|
54 |
+
|
55 |
+
@staticmethod
|
56 |
+
def backward(ctx, grad_output):
|
57 |
+
(indices,) = ctx.saved_tensors
|
58 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
59 |
+
grad_values = grad_output[indices]
|
60 |
+
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
|
61 |
+
return grad_values, None, None
|
62 |
+
|
63 |
+
|
64 |
+
index_put_first_axis = IndexPutFirstAxis.apply
|
65 |
+
|
66 |
+
|
67 |
+
class IndexFirstAxisResidual(torch.autograd.Function):
|
68 |
+
@staticmethod
|
69 |
+
def forward(ctx, input, indices):
|
70 |
+
ctx.save_for_backward(indices)
|
71 |
+
assert input.ndim >= 2
|
72 |
+
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
73 |
+
second_dim = other_shape.numel()
|
74 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
75 |
+
output = input[indices]
|
76 |
+
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
|
77 |
+
# memory format to channel_first. In other words, input might not be contiguous.
|
78 |
+
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
|
79 |
+
return output, input.detach()
|
80 |
+
|
81 |
+
@staticmethod
|
82 |
+
def backward(ctx, grad_output, grad_residual):
|
83 |
+
(indices,) = ctx.saved_tensors
|
84 |
+
assert grad_output.ndim >= 2
|
85 |
+
other_shape = grad_output.shape[1:]
|
86 |
+
assert grad_residual.shape[1:] == other_shape
|
87 |
+
grad_input = grad_residual
|
88 |
+
# grad_input[indices] += grad_output
|
89 |
+
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
|
90 |
+
indices = indices.expand_as(grad_output)
|
91 |
+
grad_input.scatter_add_(0, indices, grad_output)
|
92 |
+
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
93 |
+
|
94 |
+
|
95 |
+
index_first_axis_residual = IndexFirstAxisResidual.apply
|
96 |
+
|
97 |
+
|
98 |
+
def unpad_input(hidden_states, attention_mask):
|
99 |
+
"""
|
100 |
+
Arguments:
|
101 |
+
hidden_states: (batch, seqlen, ...)
|
102 |
+
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
103 |
+
Return:
|
104 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
|
105 |
+
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
106 |
+
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
107 |
+
max_seqlen_in_batch: int
|
108 |
+
"""
|
109 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
110 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
111 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
112 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
113 |
+
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
114 |
+
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
115 |
+
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
116 |
+
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
117 |
+
# so we write custom forward and backward to make it a bit faster.
|
118 |
+
return (
|
119 |
+
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
120 |
+
indices,
|
121 |
+
cu_seqlens,
|
122 |
+
max_seqlen_in_batch,
|
123 |
+
)
|
124 |
+
|
125 |
+
|
126 |
+
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
|
127 |
+
"""
|
128 |
+
Supports concatenating short samples into one sequence, using attention_mask_in_length to keep the concatenated sub-samples from attending to each other. This enables efficient training on variable-length samples (e.g., the supervised fine-tuning task for large language models).
|
129 |
+
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
|
130 |
+
|
131 |
+
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
|
132 |
+
```
|
133 |
+
[
|
134 |
+
[2, 3, 0, 0, 0, 0],
|
135 |
+
[3, 2, 0, 0, 0, 0],
|
136 |
+
[6, 0, 0, 0, 0, 0]
|
137 |
+
]
|
138 |
+
```
|
139 |
+
, which refers to the 3D-attention mask:
|
140 |
+
```
|
141 |
+
[
|
142 |
+
[
|
143 |
+
[1, 0, 0, 0, 0, 0],
|
144 |
+
[1, 1, 0, 0, 0, 0],
|
145 |
+
[0, 0, 1, 0, 0, 0],
|
146 |
+
[0, 0, 1, 1, 0, 0],
|
147 |
+
[0, 0, 1, 1, 1, 0],
|
148 |
+
[0, 0, 0, 0, 0, 1]
|
149 |
+
],
|
150 |
+
[
|
151 |
+
[1, 0, 0, 0, 0, 0],
|
152 |
+
[1, 1, 0, 0, 0, 0],
|
153 |
+
[1, 1, 1, 0, 0, 0],
|
154 |
+
[0, 0, 0, 1, 0, 0],
|
155 |
+
[0, 0, 0, 1, 1, 0],
|
156 |
+
[0, 0, 0, 0, 0, 1]
|
157 |
+
],
|
158 |
+
[
|
159 |
+
[1, 0, 0, 0, 0, 0],
|
160 |
+
[1, 1, 0, 0, 0, 0],
|
161 |
+
[1, 1, 1, 0, 0, 0],
|
162 |
+
[1, 1, 1, 1, 0, 0],
|
163 |
+
[1, 1, 1, 1, 1, 0],
|
164 |
+
[1, 1, 1, 1, 1, 1]
|
165 |
+
]
|
166 |
+
]
|
167 |
+
```.
|
168 |
+
|
169 |
+
Arguments:
|
170 |
+
hidden_states: (batch, seqlen, ...)
|
171 |
+
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
|
172 |
+
Return:
|
173 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
|
174 |
+
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
175 |
+
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
176 |
+
max_seqlen_in_batch: int
|
177 |
+
"""
|
178 |
+
length = attention_mask_in_length.sum(dim=-1)
|
179 |
+
seqlen = attention_mask_in_length.size(-1)
|
180 |
+
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
|
181 |
+
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
|
182 |
+
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
|
183 |
+
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
|
184 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
185 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
186 |
+
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
187 |
+
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
188 |
+
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
189 |
+
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
190 |
+
# so we write custom forward and backward to make it a bit faster.
|
191 |
+
return (
|
192 |
+
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
193 |
+
indices,
|
194 |
+
cu_seqlens,
|
195 |
+
max_seqlen_in_batch,
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
def pad_input(hidden_states, indices, batch, seqlen):
|
200 |
+
"""
|
201 |
+
Arguments:
|
202 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
|
203 |
+
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
204 |
+
batch: int, batch size for the padded sequence.
|
205 |
+
seqlen: int, maximum sequence length for the padded sequence.
|
206 |
+
Return:
|
207 |
+
hidden_states: (batch, seqlen, ...)
|
208 |
+
"""
|
209 |
+
dim = hidden_states.shape[-1]
|
210 |
+
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
211 |
+
# output[indices] = hidden_states
|
212 |
+
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
213 |
+
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
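A small round-trip sketch for unpad_input() and pad_input() above; the shapes are illustrative, and it assumes the functions defined in this file are in scope (in the flash_attn package they live in flash_attn.bert_padding, as imported by bert.py above).

import torch

batch, seqlen, dim = 2, 6, 8
hidden_states = torch.randn(batch, seqlen, dim)
# 1 = valid token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]], dtype=torch.bool)

unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
# unpadded: (8, dim) since 3 + 5 tokens are valid; cu_seqlens == [0, 3, 8]; max_seqlen == 5
repadded = pad_input(unpadded, indices, batch, seqlen)
# Padded positions come back as zeros; valid positions match the original exactly.
assert torch.equal(repadded[attention_mask], hidden_states[attention_mask])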
bigcode.py
ADDED
@@ -0,0 +1,233 @@
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig


def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.
    """

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
            r"transformer.layers.\1.norm\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.weight",
            r"transformer.layers.\1.mlp.fc1.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.weight",
            r"transformer.layers.\1.mlp.fc2.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.bias",
            r"transformer.layers.\1.mlp.fc1.bias",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.bias",
            r"transformer.layers.\1.mlp.fc2.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # TODO: add support for multi-head attention
    assert config.multi_query, "Only multi-query attention is supported"

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
        # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim)
        # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112
        # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183
        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        k = torch.tile(k, (config.n_head, 1))
        v = torch.tile(v, (config.n_head, 1))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0)

        # same deal with the bias
        c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias")
        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        k = torch.tile(k, (config.n_head,))
        v = torch.tile(v, (config.n_head,))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0)

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.weight",
            r"transformer.layers.\1.mixer.out_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias",
            r"transformer.layers.\1.mixer.out_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.

    This function is meant to be the inverse of remap_state_dict_hf_bigcode.
    """

    # Word embedding and position embeddings
    def inv_key_mapping_pos_emb(key):
        return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key)

    state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")

    word_embeddings = word_embeddings[:, : config.vocab_size]
    state_dict["transformer.wte.weight"] = word_embeddings
    state_dict["lm_head.weight"] = word_embeddings

    # LayerNorm
    def inv_key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).norm(1|2).(weight|bias)",
            r"transformer.h.\1.ln_\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLPs
    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.weight",
            r"transformer.h.\1.mlp.c_fc.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.weight",
            r"transformer.h.\1.mlp.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.bias",
            r"transformer.h.\1.mlp.c_fc.bias",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.bias",
            r"transformer.h.\1.mlp.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        q, k, v = torch.split(
            Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
        c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight

        # Same deal with the bias
        Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        q, k, v = torch.split(
            Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
        c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias

    def inv_key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.weight",
            r"transformer.h.\1.attn.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.bias",
            r"transformer.h.\1.attn.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:
    return GPT2Config(
        activation_function=bigcode_config.activation_function,
        attn_pdrop=bigcode_config.attn_pdrop,
        bos_token_id=bigcode_config.bos_token_id,
        embd_pdrop=bigcode_config.embd_pdrop,
        eos_token_id=bigcode_config.eos_token_id,
        initializer_range=bigcode_config.initializer_range,
        layer_norm_epsilon=bigcode_config.layer_norm_epsilon,
        max_batch_size=bigcode_config.max_batch_size,
        max_sequence_length=bigcode_config.max_sequence_length,
        model_type=bigcode_config.model_type,
        multi_query=bigcode_config.multi_query,
        n_embd=bigcode_config.n_embd,
        n_head=bigcode_config.n_head,
        n_inner=bigcode_config.n_inner,
        n_layer=bigcode_config.n_layer,
        n_positions=bigcode_config.n_positions,
        resid_pdrop=bigcode_config.resid_pdrop,
        scale_attn_weights=bigcode_config.scale_attn_weights,
        summary_activation=bigcode_config.summary_activation,
        summary_first_dropout=bigcode_config.summary_first_dropout,
        summary_proj_to_labels=bigcode_config.summary_proj_to_labels,
        summary_type=bigcode_config.summary_type,
        summary_use_proj=bigcode_config.summary_use_proj,
        use_cache=bigcode_config.use_cache,
        vocab_size=bigcode_config.vocab_size,
    )
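The attention remapping above is mostly shape bookkeeping: GPT-BigCode's multi-query c_attn stores one shared K head and one shared V head next to the full Q projection, and the converter tiles them so the result matches the (3 * embed_dim, embed_dim) Wqkv layout. A standalone sketch of that arithmetic with made-up sizes (not tied to any particular checkpoint):

import torch

n_head, head_dim = 4, 8
embed_dim = n_head * head_dim  # 32

# Multi-query c_attn packs Q for all heads plus a single shared K and V head.
c_attn_weight = torch.randn(embed_dim + 2 * head_dim, embed_dim)  # (48, 32)

q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
# Replicate the shared K/V head so every query head gets a copy.
k = torch.tile(k, (n_head, 1))
v = torch.tile(v, (n_head, 1))
Wqkv = torch.cat((q, k, v), dim=0)
assert Wqkv.shape == (3 * embed_dim, embed_dim)  # (96, 32), the fused Wqkv layout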
block.py
ADDED
@@ -0,0 +1,397 @@
# Copyright (c) 2024, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None. If prenorm, hidden_states = Attn/MLP(LN(residual)).
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
            return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual.
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2, = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
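Block's docstring describes the prenorm ordering (Dropout -> Add -> LN -> sublayer) and why both the sublayer output and the running residual are returned. A minimal sketch of that unfused data flow as plain tensor ops; the module and variable names here are illustrative, not the flash_attn API:

import torch
import torch.nn as nn

dim = 16
norm = nn.LayerNorm(dim)
dropout = nn.Dropout(0.1)
sublayer = nn.Linear(dim, dim)  # stand-in for the MHA mixer or the MLP

def prenorm_step(hidden_states, residual=None):
    # "hidden_states" is the previous sublayer's output, "residual" the running stream.
    dropped = dropout(hidden_states)
    residual = dropped if residual is None else dropped + residual
    hidden_states = sublayer(norm(residual))
    # Both are returned so the next step can fuse dropout + add + norm into one kernel.
    return hidden_states, residual

x = torch.randn(2, 3, dim)
h, res = prenorm_step(x)        # very first block: no residual yet
h, res = prenorm_step(h, res)   # subsequent blocks carry the residual along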
block_info.h
ADDED
@@ -0,0 +1,46 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Varlen=true>
struct BlockInfo {

    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
        , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
        {
        }

    template <typename index_t>
    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    template <typename index_t>
    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
    }

    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
    const int seqlen_k_cache;
    const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
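BlockInfo's q_offset/k_offset boil down to one choice per batch index: with fixed-length batches the kernel jumps by a whole batch stride, with varlen batches it jumps to the row where that sequence starts according to cu_seqlens. A host-side Python sketch of the same bookkeeping, for illustration only:

import torch

# cu_seqlens for three sequences of lengths 5, 2, 7 (cumulative, as the varlen kernels expect)
seqlens = torch.tensor([5, 2, 7])
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0), (1, 0))  # [0, 5, 7, 14]

def q_offset(bidb, row_stride, batch_stride=None, varlen=True):
    # Mirrors the device-side logic: varlen batches index by starting row,
    # fixed-length batches index by batch id.
    if varlen:
        return int(cu_seqlens[bidb]) * row_stride
    return bidb * batch_stride

row_stride = 64  # e.g. number of elements per token row
print([q_offset(b, row_stride) for b in range(3)])  # [0, 320, 448]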
btlm.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) 2023, Tri Dao.

import math
import json
import re
from pathlib import Path

from collections import OrderedDict

import torch
import torch.nn.functional as F

from einops import rearrange
from transformers import GPT2Config, AutoConfig, PretrainedConfig


def remap_state_dict_hf_btlm(state_dict, config):
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)

    if "transformer.wpe.weight" in state_dict:
        state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight")
        W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0)
        b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias")
        b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias")
        state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0)
        W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()
    state_dict.pop("transformer.relative_pe.slopes")  # We don't store the Alibi slopes

    def key_mapping_attn(key):
        key = re.sub(r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=btlm_config.vocab_size,
        n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions,
        n_embd=btlm_config.hidden_size,
        n_layer=btlm_config.num_hidden_layers,
        n_head=btlm_config.num_attention_heads,
        n_inner=btlm_config.n_inner,
        activation_function=btlm_config.activation_function,
        resid_pdrop=btlm_config.resid_pdrop,
        embd_pdrop=btlm_config.embd_pdrop,
        attn_pdrop=btlm_config.attn_pdrop,
        layer_norm_epsilon=btlm_config.layer_norm_epsilon,
        initializer_range=btlm_config.initializer_range,
        bos_token_id=btlm_config.bos_token_id,
        eos_token_id=btlm_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        use_alibi=btlm_config.position_embedding_type == "alibi",
        use_flash_attn=btlm_config.position_embedding_type == "alibi",  # Alibi code path requires flash_attn
        mup_width_scale=btlm_config.mup_width_scale,
        mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,
        mup_output_multiplier=btlm_config.mup_output_alpha,
        mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,
        mlp_multiple_of=1,
    )
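The BTLM MLP remapping transposes the HF Conv1D-style weights and concatenates the two input projections (c_fc and c_fc2) into a single fc1, which is how a gated (SwiGLU-style) MLP is commonly stored as one fused matrix. A shape-only sketch with made-up sizes:

import torch

n_embd, n_inner = 8, 32
# HF GPT-2 style Conv1D weights are stored as (in_features, out_features)
W1 = torch.randn(n_embd, n_inner)   # c_fc
W3 = torch.randn(n_embd, n_inner)   # c_fc2 (the gate branch)
W2 = torch.randn(n_inner, n_embd)   # c_proj

fc1_weight = torch.cat([W1.t(), W3.t()], dim=0)  # (2 * n_inner, n_embd), nn.Linear layout
fc2_weight = W2.t()                              # (n_embd, n_inner)
assert fc1_weight.shape == (2 * n_inner, n_embd)
assert fc2_weight.shape == (n_embd, n_inner)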
causality-monitor.yaml
ADDED
@@ -0,0 +1,2 @@
causality-monitor:
  _target_: src.callbacks.causality_monitor.CausalityMonitor
comet.yaml
ADDED
@@ -0,0 +1,7 @@
# https://www.comet.ml

comet:
  _target_: pytorch_lightning.loggers.comet.CometLogger
  api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
  project_name: "template-tests"
  experiment_name: ${name}
config.yaml
ADDED
@@ -0,0 +1,50 @@
# @package _global_

# specify here default training configuration
defaults:
  - _self_
  - trainer: default
  - optimizer: adamw
  - scheduler: null
  - task: sequence-model
  - model: null
  - datamodule: null
  - callbacks: default # set this to null if you don't want to use callbacks
  - metrics: null
  - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)

  - mode: default

  - experiment: null
  - hparams_search: null

  # enable color logging
  - override hydra/hydra_logging: colorlog
  - override hydra/job_logging: colorlog

# path to original working directory
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have this path as a special variable
# https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
work_dir: ${hydra:runtime.cwd}

# path to folder with data
data_dir: ${work_dir}/data/

# pretty print config at the start of the run using Rich library
print_config: True

# disable python warnings if they annoy you
ignore_warnings: True

# check performance on test set, using the best model achieved during training
# lightning chooses best model based on metric specified in checkpoint callback
test_after_training: True

resume: False

# seed for random number generators in pytorch, numpy and python.random
seed: null

# name of the run, accessed by loggers
name: null
cosine-warmup-timm.yaml
ADDED
@@ -0,0 +1,2 @@
# @package train.scheduler
_target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler
cosine-warmup.yaml
ADDED
@@ -0,0 +1,2 @@
# @package train.scheduler
_target_: transformers.get_cosine_schedule_with_warmup
cross_entropy.py
ADDED
@@ -0,0 +1,318 @@
# Copyright (c) 2023, Tri Dao.

from typing import Tuple, Optional, Union

import torch

import triton
import triton.language as tl

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        label_idx -= class_start_idx
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    probs += 2.0 * lse_square_scale * lse * probs
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)


class CrossEntropyLoss(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols

        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
        MAX_BLOCK_SIZE = 64 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        # We may split the lse computation across multiple blocks, then do a reduction
        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
        # where having just one thread block processing more than 64k elements is slow.
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
                losses,  # data ptrs
                lse,
                z_losses,
                logits,
                labels,
                smoothing,
                logit_scale,
                lse_square_scale,
                ignore_index,
                total_classes,
                class_start_idx,
                n_cols,  # shapes
                n_rows,
                logits.stride(0),  # strides
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
                SPLIT=split,
            )

        if split:
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
            # For labels not in the vocab of this partition, losses contains
            # -0.1 * sum logit / total_classes.
            if n_splits > 1:
                lse = torch.logsumexp(lse, dim=0)
                losses = losses.sum(dim=0)
            if world_size > 1:
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
                handle_losses = torch.distributed.all_reduce(
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
                )
                lse = torch.logsumexp(lse_allgather, dim=0)
                handle_losses.wait()
            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
            # we just have to add the (global) lse.
            # If there's smoothing=0.1, the total losses are
            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
            # Again, we just have to add the (global) lse.
            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignore_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            losses.masked_fill_(labels == ignore_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward

        return losses, z_losses

    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        del grad_z_losses  # z_losses are only for logging.

        logits, lse, labels = ctx.saved_tensors
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,  # data ptrs
                grad_losses,
                logits,
                lse,
                labels,
                ctx.smoothing,
                ctx.logit_scale,
                ctx.lse_square_scale,
                ctx.ignore_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,  # shapes
                logits.stride(0),  # strides
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
            )
        return dlogits, None, None, None, None, None, None, None, None


def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        logits: (batch, vocab_size)
        labels: (batch,)
        label_smoothing: float
        logit_scale: float. Multiply logits by this scale before calculating the loss.
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
            This is also referred to as "z-loss".
        ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
            This saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
    Returns:
        losses: (batch,), float
        z_losses: (batch,), float
    """
    return CrossEntropyLoss.apply(
        logits,
        labels,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
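cross_entropy_loss is a drop-in replacement for the usual softmax cross entropy with a few extras (label smoothing, logit scaling, z-loss, tensor-parallel vocab sharding). A single-GPU usage sketch that checks it against torch.nn.functional.cross_entropy; it assumes the function above is in scope (or imported from wherever this module lives in your install), and the loose tolerance reflects the fp16 logits:

import torch
import torch.nn.functional as F

batch, vocab_size = 4, 50257
logits = torch.randn(batch, vocab_size, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, vocab_size, (batch,), device="cuda")

losses, z_losses = cross_entropy_loss(logits, labels, ignore_index=-100)
# losses is per-example; reduce it yourself (here no label is ignored, so a plain mean is fine).
loss = losses.mean()
loss.backward()

ref = F.cross_entropy(logits.float().detach(), labels, reduction="mean")
print(torch.allclose(loss.detach(), ref, atol=1e-2))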
csv.yaml
ADDED
@@ -0,0 +1,8 @@
# csv logger built in lightning

csv:
  _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
  save_dir: "."
  name: "csv/"
  version: ${name}
  prefix: ""
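These logger and callback YAML snippets all follow Hydra's `_target_` convention: the config names a class plus its constructor arguments, and the training code builds the object from it. A minimal sketch of that mechanism (illustrative wiring, not this repo's exact training loop):

from omegaconf import OmegaConf
import hydra.utils

cfg = OmegaConf.create(
    {
        "csv": {
            "_target_": "pytorch_lightning.loggers.csv_logs.CSVLogger",
            "save_dir": ".",
            "name": "csv/",
        }
    }
)
logger = hydra.utils.instantiate(cfg.csv)  # builds CSVLogger(save_dir=".", name="csv/")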
cuda_bf16_fallbacks.cuh
ADDED
@@ -0,0 +1,257 @@
1 |
+
// Downloaded from from FasterTransformer v5.2.1
|
2 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
|
3 |
+
/*
|
4 |
+
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
|
19 |
+
#pragma once
|
20 |
+
|
21 |
+
#include "cuda_bf16_wrapper.h"
|
22 |
+
#include <cuda_fp16.h>
|
23 |
+
|
24 |
+
namespace fastertransformer {
|
25 |
+
|
26 |
+
#ifdef ENABLE_BF16
|
27 |
+
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
|
28 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
29 |
+
float2 f_val;
|
30 |
+
f_val.x = __low2float(val);
|
31 |
+
f_val.y = __high2float(val);
|
32 |
+
return f_val;
|
33 |
+
#else
|
34 |
+
return __bfloat1622float2(val);
|
35 |
+
#endif
|
36 |
+
}
|
37 |
+
|
38 |
+
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
|
39 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
40 |
+
float2 f_val;
|
41 |
+
f_val.x = max(min(__low2float(val), 127.f), -128.f);
|
42 |
+
f_val.y = max(min(__high2float(val), 127.f), -128.f);
|
43 |
+
union { int8_t int8[2]; int16_t int16; };
|
44 |
+
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
|
45 |
+
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
|
46 |
+
return int16;
|
47 |
+
#else
|
48 |
+
val = __hmin2(val, make_bfloat162(127., 127.));
|
49 |
+
val = __hmax2(val, make_bfloat162(-128., -128.));
|
50 |
+
union { int8_t int8[2]; int16_t int16; };
|
51 |
+
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
|
52 |
+
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
|
53 |
+
return int16;
|
54 |
+
#endif
|
55 |
+
}
|
56 |
+
|
57 |
+
inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
|
58 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
59 |
+
return __floats2bfloat162_rn(val.x, val.y);
|
60 |
+
#else
|
61 |
+
return __float22bfloat162_rn(val);
|
62 |
+
#endif
|
63 |
+
}
|
64 |
+
|
65 |
+
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
|
66 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
67 |
+
__nv_bfloat162 val2;
|
68 |
+
val2.x = val;
|
69 |
+
val2.y = val;
|
70 |
+
return val2;
|
71 |
+
#else
|
72 |
+
return __bfloat162bfloat162(val);
|
73 |
+
#endif
|
74 |
+
}
|
75 |
+
|
76 |
+
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
|
77 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
78 |
+
float fxl, fxh, fyl, fyh;
|
79 |
+
fxl = __low2float(x);
|
80 |
+
fxh = __high2float(x);
|
81 |
+
fyl = __low2float(y);
|
82 |
+
fyh = __high2float(y);
|
83 |
+
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
|
84 |
+
#else
|
85 |
+
return __hadd2(x, y);
|
86 |
+
#endif
|
87 |
+
}
|
88 |
+
|
89 |
+
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
|
90 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
91 |
+
return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
|
92 |
+
#else
|
93 |
+
return __hadd(x, y);
|
94 |
+
#endif
|
95 |
+
}
|
96 |
+
|
97 |
+
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
|
98 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
99 |
+
float fxl, fxh, fyl, fyh;
|
100 |
+
fxl = __low2float(x);
|
101 |
+
fxh = __high2float(x);
|
102 |
+
fyl = __low2float(y);
|
103 |
+
fyh = __high2float(y);
|
104 |
+
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
|
105 |
+
#else
|
106 |
+
return __hsub2(x, y);
|
107 |
+
#endif
|
108 |
+
}
|
109 |
+
|
110 |
+
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
|
111 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
112 |
+
return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
|
113 |
+
#else
|
114 |
+
return __hsub(x, y);
|
115 |
+
#endif
|
116 |
+
}
|
117 |
+
|
118 |
+
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
|
119 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
120 |
+
float fxl, fxh, fyl, fyh;
|
121 |
+
fxl = __low2float(x);
|
122 |
+
fxh = __high2float(x);
|
123 |
+
fyl = __low2float(y);
|
124 |
+
fyh = __high2float(y);
|
125 |
+
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
|
126 |
+
#else
|
127 |
+
return __hmul2(x, y);
|
128 |
+
#endif
|
129 |
+
}
|
130 |
+
|
131 |
+
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
|
132 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
133 |
+
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
|
134 |
+
#else
|
135 |
+
return __hmul(x, y);
|
136 |
+
#endif
|
137 |
+
}
|
138 |
+
|
139 |
+
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
|
140 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
141 |
+
float fxl, fxh, fyl, fyh, fzl, fzh;
|
142 |
+
fxl = __low2float(x);
|
143 |
+
fxh = __high2float(x);
|
144 |
+
fyl = __low2float(y);
|
145 |
+
fyh = __high2float(y);
|
146 |
+
fzl = __low2float(z);
|
147 |
+
fzh = __high2float(z);
|
148 |
+
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
|
149 |
+
#else
|
150 |
+
return __hfma2(x, y, z);
|
151 |
+
#endif
|
152 |
+
}
|
153 |
+
|
154 |
+
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
|
155 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
156 |
+
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
|
157 |
+
#else
|
158 |
+
return __hfma(x, y, z);
|
159 |
+
#endif
|
160 |
+
}
|
161 |
+
|
162 |
+
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
|
163 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
164 |
+
float fxl, fxh;
|
165 |
+
fxl = __low2float(x);
|
166 |
+
fxh = __high2float(x);
|
167 |
+
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
|
168 |
+
#else
|
169 |
+
return h2exp(x);
|
170 |
+
#endif
|
171 |
+
}
|
172 |
+
|
173 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
174 |
+
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
|
175 |
+
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };
|
176 |
+
|
177 |
+
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
|
178 |
+
{
|
179 |
+
__nv_bfloat162 t; t.x = x; t.y = y; return t;
|
180 |
+
}
|
181 |
+
|
182 |
+
#endif
|
183 |
+
|
184 |
+
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
|
185 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
186 |
+
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
|
187 |
+
#else
|
188 |
+
return a + b + c;
|
189 |
+
#endif
|
190 |
+
}
|
191 |
+
|
192 |
+
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
|
193 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
194 |
+
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
|
195 |
+
#else
|
196 |
+
return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
|
197 |
+
#endif
|
198 |
+
}
|
199 |
+
|
200 |
+
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
201 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
202 |
+
float fal, fah, fbl, fbh, fcl, fch;
|
203 |
+
fal = __low2float(a);
|
204 |
+
fah = __high2float(a);
|
205 |
+
fbl = __low2float(b);
|
206 |
+
fbh = __high2float(b);
|
207 |
+
fcl = __low2float(c);
|
208 |
+
fch = __high2float(c);
|
209 |
+
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
|
210 |
+
#else
|
211 |
+
return a + b + c;
|
212 |
+
#endif
|
213 |
+
}
|
214 |
+
|
215 |
+
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
|
216 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
217 |
+
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
|
218 |
+
#else
|
219 |
+
return a * b * c;
|
220 |
+
#endif
|
221 |
+
}
|
222 |
+
|
223 |
+
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
|
224 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
225 |
+
float fal, fah, fbl, fbh, fcl, fch;
|
226 |
+
fal = __low2float(a);
|
227 |
+
fah = __high2float(a);
|
228 |
+
fbl = __low2float(b);
|
229 |
+
fbh = __high2float(b);
|
230 |
+
fcl = __low2float(c);
|
231 |
+
fch = __high2float(c);
|
232 |
+
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
|
233 |
+
#else
|
234 |
+
return a * b * c;
|
235 |
+
#endif
|
236 |
+
}
|
237 |
+
|
238 |
+
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
|
239 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
240 |
+
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
|
241 |
+
fal = __low2float(a);
|
242 |
+
fah = __high2float(a);
|
243 |
+
fbl = __low2float(b);
|
244 |
+
fbh = __high2float(b);
|
245 |
+
fcl = __low2float(c);
|
246 |
+
fch = __high2float(c);
|
247 |
+
fdl = __low2float(d);
|
248 |
+
fdh = __high2float(d);
|
249 |
+
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
|
250 |
+
#else
|
251 |
+
return a * b * c + d;
|
252 |
+
#endif
|
253 |
+
}
|
254 |
+
|
255 |
+
#endif // ENABLE_BF16
|
256 |
+
|
257 |
+
} // namespace fastertransformer
|
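
All of the helpers above follow one pattern: on architectures older than sm_80 the packed bfloat16 pair is unpacked to float, the arithmetic is done in FP32, and the result is rounded back with __floats2bfloat162_rn; on sm_80 and newer the native __h* intrinsic is used. The sketch below is a minimal, hypothetical test harness (not part of this upload) that exercises bf16hfma2 and bf1622float2; it assumes the file above is saved as cuda_bf16_fallbacks.cuh next to cuda_bf16_wrapper.h and that the build defines ENABLE_BF16.

    // bf16_fallback_demo.cu -- hypothetical check of bf16hfma2 (element-wise x * y + z).
    // Assumed build line: nvcc -DENABLE_BF16 -arch=sm_70 bf16_fallback_demo.cu
    #include <cstdio>
    #include <cuda_bf16.h>
    #include "cuda_bf16_fallbacks.cuh"

    __global__ void fma_demo(float2* out) {
        __nv_bfloat162 x = __floats2bfloat162_rn(1.5f, 2.0f);
        __nv_bfloat162 y = __floats2bfloat162_rn(4.0f, 0.5f);
        __nv_bfloat162 z = __floats2bfloat162_rn(0.25f, 0.25f);
        // Expected up to bf16 rounding: (1.5*4 + 0.25, 2*0.5 + 0.25) = (6.25, 1.25).
        __nv_bfloat162 r = fastertransformer::bf16hfma2(x, y, z);
        *out = fastertransformer::bf1622float2(r);
    }

    int main() {
        float2 *d_out, h_out;
        cudaMalloc(&d_out, sizeof(float2));
        fma_demo<<<1, 1>>>(d_out);
        cudaMemcpy(&h_out, d_out, sizeof(float2), cudaMemcpyDeviceToHost);
        printf("bf16hfma2 -> (%f, %f)\n", h_out.x, h_out.y);
        cudaFree(d_out);
        return 0;
    }

On sm_70 the call goes through the FP32 fallback branch; on sm_80 it compiles to the native __hfma2, with the same result either way.
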
cuda_bf16_wrapper.h
ADDED
@@ -0,0 +1,23 @@
1 |
+
// Downloaded from FasterTransformer v5.2.1
|
2 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
|
3 |
+
/*
|
4 |
+
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
|
19 |
+
#pragma once
|
20 |
+
|
21 |
+
#ifdef ENABLE_BF16
|
22 |
+
#include <cuda_bf16.h>
|
23 |
+
#endif
|
ddp.yaml
ADDED
@@ -0,0 +1,6 @@
1 |
+
defaults:
|
2 |
+
- default.yaml
|
3 |
+
|
4 |
+
accelerator: gpu
|
5 |
+
devices: 4
|
6 |
+
strategy: ddp
|
debug.yaml
ADDED
@@ -0,0 +1,27 @@
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# run in debug mode with:
|
4 |
+
# `python run.py mode=debug`
|
5 |
+
|
6 |
+
defaults:
|
7 |
+
- override /trainer: debug.yaml
|
8 |
+
|
9 |
+
debug_mode: True
|
10 |
+
|
11 |
+
hydra:
|
12 |
+
# sets level of all command line loggers to 'DEBUG'
|
13 |
+
verbose: True
|
14 |
+
|
15 |
+
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
|
16 |
+
# sets level of only chosen command line loggers to 'DEBUG'
|
17 |
+
# verbose: [src.train, src.utils.utils]
|
18 |
+
|
19 |
+
# sets output paths for all file logs to 'logs/debug/'
|
20 |
+
run:
|
21 |
+
dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
22 |
+
sweep:
|
23 |
+
dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S}
|
24 |
+
subdir: ${hydra.job.num}
|
25 |
+
|
26 |
+
# disable rich config printing, since it will already be printed by hydra when `verbose: True`
|
27 |
+
print_config: False
|
decoder_masked_multihead_attention.cu
ADDED
@@ -0,0 +1,149 @@
1 |
+
// Adapted from FasterTransformer v5.2.1
|
2 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
|
3 |
+
/*
|
4 |
+
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
|
19 |
+
#include "decoder_masked_multihead_attention.h"
|
20 |
+
#include "decoder_masked_multihead_attention_utils.h"
|
21 |
+
#include "cuda_bf16_wrapper.h"
|
22 |
+
#include <assert.h>
|
23 |
+
#include <float.h>
|
24 |
+
#include <type_traits>
|
25 |
+
|
26 |
+
#include "decoder_masked_multihead_attention_template.hpp"
|
27 |
+
|
28 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
29 |
+
|
30 |
+
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
|
31 |
+
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
|
32 |
+
auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
|
33 |
+
THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
|
34 |
+
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
|
35 |
+
dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \
|
36 |
+
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
|
37 |
+
|
38 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
39 |
+
|
40 |
+
// !!! Specialize the launcher for Cross attention
|
41 |
+
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
|
42 |
+
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
|
43 |
+
{
|
44 |
+
constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
|
45 |
+
constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
|
46 |
+
int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
|
47 |
+
// printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
|
48 |
+
if (tlength < 32) {
|
49 |
+
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
|
50 |
+
}
|
51 |
+
else if (tlength < 2048) {
|
52 |
+
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
|
53 |
+
}
|
54 |
+
else {
|
55 |
+
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
|
56 |
+
}
|
57 |
+
}
|
58 |
+
|
59 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
60 |
+
|
61 |
+
#undef MMHA_LAUNCH_KERNEL
|
62 |
+
|
63 |
+
template<typename T, typename KERNEL_PARAMS_TYPE>
|
64 |
+
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
|
65 |
+
{
|
66 |
+
switch (params.hidden_size_per_head) {
|
67 |
+
case 32:
|
68 |
+
mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
|
69 |
+
break;
|
70 |
+
case 48:
|
71 |
+
mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
|
72 |
+
break;
|
73 |
+
case 64:
|
74 |
+
mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
|
75 |
+
break;
|
76 |
+
case 80:
|
77 |
+
mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
|
78 |
+
break;
|
79 |
+
case 96:
|
80 |
+
mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
|
81 |
+
break;
|
82 |
+
case 128:
|
83 |
+
mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
|
84 |
+
break;
|
85 |
+
case 160:
|
86 |
+
mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
|
87 |
+
break;
|
88 |
+
case 192:
|
89 |
+
mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
|
90 |
+
break;
|
91 |
+
case 224:
|
92 |
+
mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
|
93 |
+
break;
|
94 |
+
case 256:
|
95 |
+
mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
|
96 |
+
break;
|
97 |
+
default:
|
98 |
+
assert(false);
|
99 |
+
}
|
100 |
+
}
|
101 |
+
|
102 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
103 |
+
|
104 |
+
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
|
105 |
+
{
|
106 |
+
multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
|
107 |
+
}
|
108 |
+
|
109 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
110 |
+
|
111 |
+
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
|
112 |
+
{
|
113 |
+
multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
|
114 |
+
}
|
115 |
+
|
116 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
117 |
+
|
118 |
+
#ifdef ENABLE_BF16
|
119 |
+
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
|
120 |
+
const cudaStream_t& stream)
|
121 |
+
{
|
122 |
+
multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
|
123 |
+
}
|
124 |
+
#endif
|
125 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
126 |
+
|
127 |
+
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
|
128 |
+
{
|
129 |
+
multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
|
130 |
+
}
|
131 |
+
|
132 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
133 |
+
|
134 |
+
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
|
135 |
+
{
|
136 |
+
multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
|
137 |
+
}
|
138 |
+
|
139 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
140 |
+
|
141 |
+
#ifdef ENABLE_BF16
|
142 |
+
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
|
143 |
+
const cudaStream_t& stream)
|
144 |
+
{
|
145 |
+
multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
|
146 |
+
}
|
147 |
+
#endif
|
148 |
+
|
149 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
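
Two numbers drive the launch configuration in the launcher above: THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16 assigns one 16-byte chunk of a value row to each thread, and the attended length (timestep for self-attention, memory_max_len for cross-attention) selects the block size (64 threads below 32 steps, 128 below 2048, 256 beyond). The stand-alone sketch below simply restates that arithmetic for illustration; it is not part of the file.

    // launch_math_demo.cpp -- restates the selection logic of mmha_launch_kernel.
    #include <cstdio>

    int threads_per_value(int dh_max, int bytes_per_elt) {
        // One thread per 16-byte chunk of a value row.
        return dh_max * bytes_per_elt / 16;
    }

    int threads_per_block(int tlength) {
        return tlength < 32 ? 64 : (tlength < 2048 ? 128 : 256);
    }

    int main() {
        // FP16 (2 bytes) with the head size padded to Dh_MAX = 128:
        // 128 * 2 / 16 = 16 threads cooperate on each value row, each loading
        // 8 half values (one uint4) per timestep it processes.
        printf("THREADS_PER_VALUE = %d\n", threads_per_value(128, 2));
        // A decode step attending over 1024 cached timesteps uses a 128-thread block.
        printf("THREADS_PER_BLOCK = %d\n", threads_per_block(1024));
        return 0;
    }
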
decoder_masked_multihead_attention.h
ADDED
@@ -0,0 +1,192 @@
1 |
+
// Downloaded from FasterTransformer v5.2.1
|
2 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
|
3 |
+
/*
|
4 |
+
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
|
19 |
+
#pragma once
|
20 |
+
|
21 |
+
#include "cuda_bf16_wrapper.h"
|
22 |
+
#include <cuda_fp16.h>
|
23 |
+
#include <cuda_runtime_api.h>
|
24 |
+
#include <stdint.h>
|
25 |
+
#include <stdio.h>
|
26 |
+
#include <stdlib.h>
|
27 |
+
|
28 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
29 |
+
|
30 |
+
#define CHECK_CUDA(call) \
|
31 |
+
do { \
|
32 |
+
cudaError_t status_ = call; \
|
33 |
+
if (status_ != cudaSuccess) { \
|
34 |
+
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
|
35 |
+
exit(1); \
|
36 |
+
} \
|
37 |
+
} while (0)
|
38 |
+
|
39 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
40 |
+
|
41 |
+
// The structure of parameters for the masked multihead attention kernel.
|
42 |
+
//
|
43 |
+
// We use the following terminology to describe the different dimensions.
|
44 |
+
//
|
45 |
+
// B: Batch size (number of sequences),
|
46 |
+
// L: Sequence length,
|
47 |
+
// D: Hidden dimension,
|
48 |
+
// H: Number of heads,
|
49 |
+
// Dh: Hidden dimension per head - Dh = D / H.
|
50 |
+
|
51 |
+
template<typename T>
|
52 |
+
struct Multihead_attention_params_base {
|
53 |
+
|
54 |
+
// The output buffer. Dimensions B x D.
|
55 |
+
T* out = nullptr;
|
56 |
+
|
57 |
+
// The input Qs and the associated bias. Dimensions B x D and D, resp.
|
58 |
+
const T *q = nullptr, *q_bias = nullptr;
|
59 |
+
// The input Ks and the associated bias. Dimensions B x D and D, resp.
|
60 |
+
const T *k = nullptr, *k_bias = nullptr;
|
61 |
+
// The input Vs and the associated bias. Dimensions B x D and D, resp.
|
62 |
+
const T *v = nullptr, *v_bias = nullptr;
|
63 |
+
|
64 |
+
// The cache for the Ks. The size must be at least B x L x D.
|
65 |
+
T* k_cache = nullptr;
|
66 |
+
// The cache for the Vs. The size must be at least B x L x D.
|
67 |
+
T* v_cache = nullptr;
|
68 |
+
// The indirections to use for cache when beam sampling.
|
69 |
+
const int* cache_indir = nullptr;
|
70 |
+
|
71 |
+
// Stride to handle the case when KQV is a single buffer
|
72 |
+
int stride_q = 0;
|
73 |
+
int stride_k = 0;
|
74 |
+
int stride_v = 0;
|
75 |
+
|
76 |
+
// The batch size.
|
77 |
+
int batch_size = 0;
|
78 |
+
// The beam width
|
79 |
+
int beam_width = 0;
|
80 |
+
// The sequence length.
|
81 |
+
int memory_max_len = 0;
|
82 |
+
// The number of heads (H).
|
83 |
+
int num_heads = 0;
|
84 |
+
int num_heads_kv = 0;
|
85 |
+
int num_heads_q_kv_ratio = 0;
|
86 |
+
// The hidden dimension per head (Dh).
|
87 |
+
int hidden_size_per_head = 0;
|
88 |
+
// The per-head latent space reserved for rotary embeddings.
|
89 |
+
int rotary_embedding_dim = 0;
|
90 |
+
bool neox_rotary_style = false;
|
91 |
+
float rotary_base = 0.0f;
|
92 |
+
// The maximum length of input sentences.
|
93 |
+
int max_input_length = 0;
|
94 |
+
// The current timestep. TODO(bhsueh) Check whether we only use this param in cross attention.
|
95 |
+
int timestep = 0;
|
96 |
+
// The current timestep of each sentence (supports different timesteps for different sentences)
|
97 |
+
|
98 |
+
// The 1.f / sqrt(Dh). Computed on the host.
|
99 |
+
float inv_sqrt_dh = 0.0f;
|
100 |
+
|
101 |
+
// Used when we have some input context like gpt
|
102 |
+
const int* total_padding_tokens = nullptr;
|
103 |
+
|
104 |
+
const bool* masked_tokens = nullptr;
|
105 |
+
const int* prefix_prompt_lengths = nullptr;
|
106 |
+
int max_prefix_prompt_length = 0;
|
107 |
+
|
108 |
+
const T* relative_attention_bias = nullptr;
|
109 |
+
int relative_attention_bias_stride = 0;
|
110 |
+
// The slope per head of linear position bias to attention score (H).
|
111 |
+
const T* linear_bias_slopes = nullptr;
|
112 |
+
|
113 |
+
const T* ia3_key_weights = nullptr;
|
114 |
+
const T* ia3_value_weights = nullptr;
|
115 |
+
const int* ia3_tasks = nullptr;
|
116 |
+
|
117 |
+
const float* qkv_scale_out = nullptr;
|
118 |
+
const float* attention_out_scale = nullptr;
|
119 |
+
int int8_mode = 0;
|
120 |
+
|
121 |
+
const T *rotary_cos = nullptr;
|
122 |
+
const T *rotary_sin = nullptr;
|
123 |
+
|
124 |
+
const int *nnz_head_idx = nullptr;
|
125 |
+
int nnz_heads = 0;
|
126 |
+
};
|
127 |
+
|
128 |
+
template<typename T, bool CROSS_ATTENTION>
|
129 |
+
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
|
130 |
+
// output cross attentions
|
131 |
+
float* cross_attention_out = nullptr;
|
132 |
+
int max_decoder_seq_len = 0;
|
133 |
+
bool is_return_cross_attentions = false;
|
134 |
+
|
135 |
+
// allows attention to exit early
|
136 |
+
bool* finished = nullptr;
|
137 |
+
|
138 |
+
// required in case of cross attention
|
139 |
+
// will need it here until we can use if constexpr (C++17)
|
140 |
+
int* memory_length_per_sample = nullptr;
|
141 |
+
|
142 |
+
// required in case of masked attention with different length
|
143 |
+
const int* length_per_sample = nullptr;
|
144 |
+
};
|
145 |
+
|
146 |
+
template<typename T>
|
147 |
+
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
|
148 |
+
// output cross attentions
|
149 |
+
float* cross_attention_out = nullptr;
|
150 |
+
int max_decoder_seq_len = 0;
|
151 |
+
bool is_return_cross_attentions = false;
|
152 |
+
|
153 |
+
// allows attention to exit early
|
154 |
+
bool* finished = nullptr;
|
155 |
+
|
156 |
+
// required in case of cross attention
|
157 |
+
int* memory_length_per_sample = nullptr;
|
158 |
+
|
159 |
+
// required in case of masked attention with different length
|
160 |
+
const int* length_per_sample = nullptr;
|
161 |
+
};
|
162 |
+
|
163 |
+
template<class T>
|
164 |
+
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;
|
165 |
+
|
166 |
+
template<class T>
|
167 |
+
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
|
168 |
+
|
169 |
+
template<typename T>
|
170 |
+
struct outputCrossAttentionParam {
|
171 |
+
// max decoder output length
|
172 |
+
int max_decoder_seq_len = 0;
|
173 |
+
T* cross_attention_out = nullptr;
|
174 |
+
bool is_return_cross_attentions = false;
|
175 |
+
};
|
176 |
+
|
177 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
178 |
+
|
179 |
+
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
|
180 |
+
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
|
181 |
+
#ifdef ENABLE_BF16
|
182 |
+
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
|
183 |
+
const cudaStream_t& stream);
|
184 |
+
#endif
|
185 |
+
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
|
186 |
+
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
|
187 |
+
#ifdef ENABLE_BF16
|
188 |
+
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
|
189 |
+
const cudaStream_t& stream);
|
190 |
+
#endif
|
191 |
+
|
192 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
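
The header above only exposes the parameter struct and the float / uint16_t / __nv_bfloat16 entry points, so a caller fills in Masked_multihead_attention_params and invokes masked_multihead_attention once per decoding step. Below is a hedged, minimal call-site sketch for FP16 self-attention: the function name and the surrounding allocation are assumptions, and only the bookkeeping fields declared above are set (everything else keeps its nullptr/zero default).

    // decode_step_sketch.cu -- hypothetical single decoding step using the FP16 entry point.
    #include <cmath>
    #include "decoder_masked_multihead_attention.h"

    void decode_one_step(uint16_t* out, const uint16_t* q, const uint16_t* k, const uint16_t* v,
                         uint16_t* k_cache, uint16_t* v_cache,
                         int batch_size, int num_heads, int head_dim,
                         int timestep, int max_seq_len, cudaStream_t stream) {
        Masked_multihead_attention_params<uint16_t> params;
        params.out = out;
        params.q = q;  params.k = k;  params.v = v;   // current-step Q/K/V, each B x D
        params.k_cache = k_cache;                     // [B, H, Dh/x, L, x] layout
        params.v_cache = v_cache;                     // [B, H, L, Dh] layout
        params.batch_size = batch_size;
        params.beam_width = 1;
        params.num_heads = num_heads;
        params.num_heads_kv = num_heads;              // no grouped-query sharing in this sketch
        params.num_heads_q_kv_ratio = 1;
        params.hidden_size_per_head = head_dim;       // must be one of the dispatched sizes (32..256)
        params.memory_max_len = max_seq_len;
        params.timestep = timestep;
        params.inv_sqrt_dh = 1.f / std::sqrt(static_cast<float>(head_dim));
        masked_multihead_attention(params, stream);
    }
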
decoder_masked_multihead_attention_template.hpp
ADDED
@@ -0,0 +1,1619 @@
1 |
+
// Downloaded from FasterTransformer v5.2.1
|
2 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
3 |
+
/*
|
4 |
+
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
#pragma once
|
19 |
+
|
20 |
+
#include "decoder_masked_multihead_attention.h"
|
21 |
+
#include "decoder_masked_multihead_attention_utils.h"
|
22 |
+
#include "cuda_bf16_wrapper.h"
|
23 |
+
#include "cuda_bf16_fallbacks.cuh"
|
24 |
+
#include <assert.h>
|
25 |
+
#include <float.h>
|
26 |
+
#include <type_traits>
|
27 |
+
|
28 |
+
// #define MMHA_USE_HMMA_FOR_REDUCTION
|
29 |
+
|
30 |
+
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
|
31 |
+
|
32 |
+
// Does not seem to affect the accuracy that much
|
33 |
+
#define MMHA_USE_FP32_ACUM_FOR_FMA
|
34 |
+
|
35 |
+
// Seems to slightly improve the accuracy
|
36 |
+
#define MMHA_USE_FP32_ACUM_FOR_OUT
|
37 |
+
|
38 |
+
#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
|
39 |
+
// Does not seem to improve the accuracy
|
40 |
+
//#define MMHA_USE_FP32_ACUM_FOR_LOGITS
|
41 |
+
#endif
|
42 |
+
|
43 |
+
namespace mmha {
|
44 |
+
|
45 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
46 |
+
|
47 |
+
//
|
48 |
+
// We use the following terminology to describe the different dimensions.
|
49 |
+
//
|
50 |
+
// B: Batch size (number of sequences),
|
51 |
+
// L: Sequence length,
|
52 |
+
// D: Hidden dimension,
|
53 |
+
// H: Number of heads,
|
54 |
+
// Dh: Hidden dimension per head - Dh = D / H.
|
55 |
+
//
|
56 |
+
// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use
|
57 |
+
// 64, 128 and 256 threads per block.
|
58 |
+
//
|
59 |
+
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
|
60 |
+
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
|
61 |
+
// cache buffer helps with memory accesses and contains keys with bias.
|
62 |
+
//
|
63 |
+
// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
|
64 |
+
// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
|
65 |
+
// values for x are chosen to create chunks of 16 bytes.
|
66 |
+
//
|
67 |
+
// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
|
68 |
+
// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
|
69 |
+
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
|
70 |
+
// HMMA instruction (Tensor Core). Each Q * K^T value is stored in shared memory in FP32.
|
71 |
+
//
|
72 |
+
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
|
73 |
+
// shared memory.
|
74 |
+
//
|
75 |
+
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
|
76 |
+
// timesteps are computed by loop iteration. As with the keys, the values are read from a cache
|
77 |
+
// except for the current timestep. The layout of the cache buffer for the values is much simpler
|
78 |
+
// as it is [B, H, L, Dh].
|
79 |
+
//
|
80 |
+
|
81 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
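// [Editorial aside -- illustrative only, not part of the FasterTransformer source above.]
// The key-cache layout described in the comment block, [B, H, Dh/x, L, x] with
// x = 16 / sizeof(T) (so x == 8 for FP16 and x == 4 for FP32), maps one scalar of head
// dimension d at timestep t to a flat offset as in this hypothetical helper:
//
//     template<typename T>
//     __device__ size_t k_cache_offset(int b, int h, int d, int t,
//                                      int num_heads, int dh, int memory_max_len) {
//         constexpr int x = 16 / sizeof(T);   // scalars per 16-byte chunk
//         return (((size_t(b) * num_heads + h) * (dh / x) + d / x) * memory_max_len + t) * x
//                + d % x;
//     }
//
// With this order, walking consecutive timesteps of one 16-byte chunk reads contiguous
// memory, which is what keeps the K loads coalesced; the value cache stays in the plain
// row-major [B, H, L, Dh] order used by the final V loop.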
82 |
+
|
83 |
+
template<typename T, int Dh>
|
84 |
+
struct Qk_vec_ {
|
85 |
+
};
|
86 |
+
|
87 |
+
template<>
|
88 |
+
struct Qk_vec_<float, 32> {
|
89 |
+
using Type = float;
|
90 |
+
};
|
91 |
+
template<>
|
92 |
+
struct Qk_vec_<float, 64> {
|
93 |
+
using Type = float2;
|
94 |
+
};
|
95 |
+
template<>
|
96 |
+
struct Qk_vec_<float, 128> {
|
97 |
+
using Type = float4;
|
98 |
+
};
|
99 |
+
template<>
|
100 |
+
struct Qk_vec_<float, 256> {
|
101 |
+
using Type = float4;
|
102 |
+
};
|
103 |
+
template<>
|
104 |
+
struct Qk_vec_<uint16_t, 32> {
|
105 |
+
using Type = uint32_t;
|
106 |
+
};
|
107 |
+
template<>
|
108 |
+
struct Qk_vec_<uint16_t, 64> {
|
109 |
+
using Type = uint32_t;
|
110 |
+
};
|
111 |
+
template<>
|
112 |
+
struct Qk_vec_<uint16_t, 128> {
|
113 |
+
using Type = uint2;
|
114 |
+
};
|
115 |
+
template<>
|
116 |
+
struct Qk_vec_<uint16_t, 256> {
|
117 |
+
using Type = uint4;
|
118 |
+
};
|
119 |
+
#ifdef ENABLE_BF16
|
120 |
+
template<>
|
121 |
+
struct Qk_vec_<__nv_bfloat16, 32> {
|
122 |
+
using Type = __nv_bfloat162;
|
123 |
+
};
|
124 |
+
template<>
|
125 |
+
struct Qk_vec_<__nv_bfloat16, 64> {
|
126 |
+
using Type = __nv_bfloat162;
|
127 |
+
};
|
128 |
+
template<>
|
129 |
+
struct Qk_vec_<__nv_bfloat16, 128> {
|
130 |
+
using Type = bf16_4_t;
|
131 |
+
};
|
132 |
+
template<>
|
133 |
+
struct Qk_vec_<__nv_bfloat16, 256> {
|
134 |
+
using Type = bf16_8_t;
|
135 |
+
};
|
136 |
+
#endif // ENABLE_BF16
|
137 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
138 |
+
|
139 |
+
template<typename T, int THREADS_PER_KEY>
|
140 |
+
struct K_vec_ {
|
141 |
+
};
|
142 |
+
|
143 |
+
template<>
|
144 |
+
struct K_vec_<float, 4> {
|
145 |
+
using Type = float;
|
146 |
+
};
|
147 |
+
template<>
|
148 |
+
struct K_vec_<float, 2> {
|
149 |
+
using Type = float2;
|
150 |
+
};
|
151 |
+
template<>
|
152 |
+
struct K_vec_<float, 1> {
|
153 |
+
using Type = float4;
|
154 |
+
};
|
155 |
+
template<>
|
156 |
+
struct K_vec_<uint16_t, 4> {
|
157 |
+
using Type = uint32_t;
|
158 |
+
};
|
159 |
+
template<>
|
160 |
+
struct K_vec_<uint16_t, 2> {
|
161 |
+
using Type = uint2;
|
162 |
+
};
|
163 |
+
template<>
|
164 |
+
struct K_vec_<uint16_t, 1> {
|
165 |
+
using Type = uint4;
|
166 |
+
};
|
167 |
+
#ifdef ENABLE_BF16
|
168 |
+
template<>
|
169 |
+
struct K_vec_<__nv_bfloat16, 4> {
|
170 |
+
using Type = __nv_bfloat162;
|
171 |
+
};
|
172 |
+
template<>
|
173 |
+
struct K_vec_<__nv_bfloat16, 2> {
|
174 |
+
using Type = bf16_4_t;
|
175 |
+
};
|
176 |
+
template<>
|
177 |
+
struct K_vec_<__nv_bfloat16, 1> {
|
178 |
+
using Type = bf16_8_t;
|
179 |
+
};
|
180 |
+
#endif // ENABLE_BF16
|
181 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
182 |
+
|
183 |
+
template<typename T, int V_VEC_SIZE>
|
184 |
+
struct V_vec_ {
|
185 |
+
};
|
186 |
+
|
187 |
+
template<>
|
188 |
+
struct V_vec_<float, 1> {
|
189 |
+
using Type = float;
|
190 |
+
};
|
191 |
+
template<>
|
192 |
+
struct V_vec_<float, 2> {
|
193 |
+
using Type = float2;
|
194 |
+
};
|
195 |
+
template<>
|
196 |
+
struct V_vec_<float, 4> {
|
197 |
+
using Type = float4;
|
198 |
+
};
|
199 |
+
template<>
|
200 |
+
struct V_vec_<uint16_t, 2> {
|
201 |
+
using Type = uint32_t;
|
202 |
+
};
|
203 |
+
template<>
|
204 |
+
struct V_vec_<uint16_t, 4> {
|
205 |
+
using Type = uint2;
|
206 |
+
};
|
207 |
+
template<>
|
208 |
+
struct V_vec_<uint16_t, 8> {
|
209 |
+
using Type = uint4;
|
210 |
+
};
|
211 |
+
#ifdef ENABLE_BF16
|
212 |
+
template<>
|
213 |
+
struct V_vec_<__nv_bfloat16, 2> {
|
214 |
+
using Type = __nv_bfloat162;
|
215 |
+
};
|
216 |
+
template<>
|
217 |
+
struct V_vec_<__nv_bfloat16, 4> {
|
218 |
+
using Type = bf16_4_t;
|
219 |
+
};
|
220 |
+
template<>
|
221 |
+
struct V_vec_<__nv_bfloat16, 8> {
|
222 |
+
using Type = bf16_8_t;
|
223 |
+
};
|
224 |
+
#endif // ENABLE_BF16
|
225 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
226 |
+
|
227 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
228 |
+
template<typename T>
|
229 |
+
struct Qk_vec_acum_fp32_ {
|
230 |
+
};
|
231 |
+
|
232 |
+
template<>
|
233 |
+
struct Qk_vec_acum_fp32_<float> {
|
234 |
+
using Type = float;
|
235 |
+
};
|
236 |
+
template<>
|
237 |
+
struct Qk_vec_acum_fp32_<float2> {
|
238 |
+
using Type = float2;
|
239 |
+
};
|
240 |
+
template<>
|
241 |
+
struct Qk_vec_acum_fp32_<float4> {
|
242 |
+
using Type = float4;
|
243 |
+
};
|
244 |
+
// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
|
245 |
+
template<>
|
246 |
+
struct Qk_vec_acum_fp32_<uint32_t> {
|
247 |
+
using Type = float2;
|
248 |
+
};
|
249 |
+
template<>
|
250 |
+
struct Qk_vec_acum_fp32_<uint2> {
|
251 |
+
using Type = Float4_;
|
252 |
+
};
|
253 |
+
template<>
|
254 |
+
struct Qk_vec_acum_fp32_<uint4> {
|
255 |
+
using Type = Float8_;
|
256 |
+
};
|
257 |
+
template<>
|
258 |
+
struct Qk_vec_acum_fp32_<__nv_bfloat16> {
|
259 |
+
using Type = float;
|
260 |
+
};
|
261 |
+
template<>
|
262 |
+
struct Qk_vec_acum_fp32_<__nv_bfloat162> {
|
263 |
+
using Type = float2;
|
264 |
+
};
|
265 |
+
template<>
|
266 |
+
struct Qk_vec_acum_fp32_<bf16_4_t> {
|
267 |
+
using Type = Float4_;
|
268 |
+
};
|
269 |
+
template<>
|
270 |
+
struct Qk_vec_acum_fp32_<bf16_8_t> {
|
271 |
+
using Type = Float8_;
|
272 |
+
};
|
273 |
+
|
274 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
275 |
+
|
276 |
+
template<typename T>
|
277 |
+
struct K_vec_acum_fp32_ {
|
278 |
+
};
|
279 |
+
|
280 |
+
template<>
|
281 |
+
struct K_vec_acum_fp32_<float> {
|
282 |
+
using Type = float;
|
283 |
+
};
|
284 |
+
template<>
|
285 |
+
struct K_vec_acum_fp32_<float2> {
|
286 |
+
using Type = float2;
|
287 |
+
};
|
288 |
+
template<>
|
289 |
+
struct K_vec_acum_fp32_<float4> {
|
290 |
+
using Type = float4;
|
291 |
+
};
|
292 |
+
template<>
|
293 |
+
struct K_vec_acum_fp32_<uint32_t> {
|
294 |
+
using Type = float2;
|
295 |
+
};
|
296 |
+
template<>
|
297 |
+
struct K_vec_acum_fp32_<uint2> {
|
298 |
+
using Type = Float4_;
|
299 |
+
};
|
300 |
+
template<>
|
301 |
+
struct K_vec_acum_fp32_<uint4> {
|
302 |
+
using Type = Float8_;
|
303 |
+
};
|
304 |
+
template<>
|
305 |
+
struct K_vec_acum_fp32_<__nv_bfloat16> {
|
306 |
+
using Type = float;
|
307 |
+
};
|
308 |
+
template<>
|
309 |
+
struct K_vec_acum_fp32_<__nv_bfloat162> {
|
310 |
+
using Type = float2;
|
311 |
+
};
|
312 |
+
template<>
|
313 |
+
struct K_vec_acum_fp32_<bf16_4_t> {
|
314 |
+
using Type = Float4_;
|
315 |
+
};
|
316 |
+
template<>
|
317 |
+
struct K_vec_acum_fp32_<bf16_8_t> {
|
318 |
+
using Type = Float8_;
|
319 |
+
};
|
320 |
+
#endif
|
321 |
+
|
322 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
323 |
+
|
324 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
|
325 |
+
template<typename T>
|
326 |
+
struct V_vec_acum_fp32_ {
|
327 |
+
};
|
328 |
+
|
329 |
+
template<>
|
330 |
+
struct V_vec_acum_fp32_<float> {
|
331 |
+
using Type = float;
|
332 |
+
};
|
333 |
+
template<>
|
334 |
+
struct V_vec_acum_fp32_<float2> {
|
335 |
+
using Type = float2;
|
336 |
+
};
|
337 |
+
template<>
|
338 |
+
struct V_vec_acum_fp32_<float4> {
|
339 |
+
using Type = float4;
|
340 |
+
};
|
341 |
+
template<>
|
342 |
+
struct V_vec_acum_fp32_<uint32_t> {
|
343 |
+
using Type = float2;
|
344 |
+
};
|
345 |
+
template<>
|
346 |
+
struct V_vec_acum_fp32_<uint2> {
|
347 |
+
using Type = Float4_;
|
348 |
+
};
|
349 |
+
template<>
|
350 |
+
struct V_vec_acum_fp32_<uint4> {
|
351 |
+
using Type = Float8_;
|
352 |
+
};
|
353 |
+
#ifdef ENABLE_BF16
|
354 |
+
template<>
|
355 |
+
struct V_vec_acum_fp32_<__nv_bfloat162> {
|
356 |
+
using Type = float2;
|
357 |
+
};
|
358 |
+
template<>
|
359 |
+
struct V_vec_acum_fp32_<bf16_4_t> {
|
360 |
+
using Type = Float4_;
|
361 |
+
};
|
362 |
+
template<>
|
363 |
+
struct V_vec_acum_fp32_<bf16_8_t> {
|
364 |
+
using Type = Float8_;
|
365 |
+
};
|
366 |
+
#endif // ENABLE_BF16
|
367 |
+
#endif
|
368 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
369 |
+
|
370 |
+
template<int THREADS_PER_KEY, typename K_vec, int N>
|
371 |
+
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
|
372 |
+
{
|
373 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
374 |
+
using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
|
375 |
+
#else
|
376 |
+
using K_vec_acum = K_vec;
|
377 |
+
#endif
|
378 |
+
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
379 |
+
K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
|
380 |
+
#pragma unroll
|
381 |
+
for (int ii = 1; ii < N; ++ii) {
|
382 |
+
qk_vec = fma(q[ii], k[ii], qk_vec);
|
383 |
+
}
|
384 |
+
|
385 |
+
// Finalize the reduction across lanes.
|
386 |
+
float qk = sum(qk_vec);
|
387 |
+
#pragma unroll
|
388 |
+
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
|
389 |
+
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
390 |
+
}
|
391 |
+
return qk;
|
392 |
+
}
|
393 |
+
|
394 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
395 |
+
|
396 |
+
template<typename T, int THREADS_PER_KEY>
|
397 |
+
struct Qk_dot {
|
398 |
+
template<typename K_vec, int N>
|
399 |
+
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
|
400 |
+
{
|
401 |
+
return qk_dot_<THREADS_PER_KEY>(q, k);
|
402 |
+
}
|
403 |
+
};
|
404 |
+
|
405 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
406 |
+
|
407 |
+
inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
|
408 |
+
{
|
409 |
+
float4 c;
|
410 |
+
float zero = 0.f;
|
411 |
+
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
|
412 |
+
" {%0, %1, %2, %3}, \n"
|
413 |
+
" {%4, %5}, \n"
|
414 |
+
" {%6}, \n"
|
415 |
+
" {%7, %7, %7, %7}; \n"
|
416 |
+
|
417 |
+
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
|
418 |
+
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
|
419 |
+
return c;
|
420 |
+
}
|
421 |
+
|
422 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
423 |
+
|
424 |
+
template<int N>
|
425 |
+
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
|
426 |
+
{
|
427 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
428 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
429 |
+
using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
|
430 |
+
#else
|
431 |
+
using K_vec_acum = uint32_t;
|
432 |
+
#endif
|
433 |
+
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
|
434 |
+
#pragma unroll
|
435 |
+
for (int ii = 1; ii < N; ++ii) {
|
436 |
+
qk_vec = fma(q[ii], k[ii], qk_vec);
|
437 |
+
}
|
438 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
439 |
+
uint32_t qk_vec_ = float2_to_half2(qk_vec);
|
440 |
+
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
|
441 |
+
#else
|
442 |
+
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
|
443 |
+
#endif
|
444 |
+
#else
|
445 |
+
return 0.f;
|
446 |
+
#endif
|
447 |
+
}
|
448 |
+
|
449 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
450 |
+
|
451 |
+
template<>
|
452 |
+
struct Qk_dot<uint16_t, 4> {
|
453 |
+
template<int N>
|
454 |
+
static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
|
455 |
+
{
|
456 |
+
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
|
457 |
+
return qk_hmma_dot_(q, k);
|
458 |
+
#else
|
459 |
+
return qk_dot_<4>(q, k);
|
460 |
+
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
|
461 |
+
}
|
462 |
+
};
|
463 |
+
|
464 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
465 |
+
|
466 |
+
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
|
467 |
+
inline __device__ float block_sum(float* red_smem, float sum)
|
468 |
+
{
|
469 |
+
|
470 |
+
// Decompose the thread index into warp / lane.
|
471 |
+
int warp = threadIdx.x / WARP_SIZE;
|
472 |
+
int lane = threadIdx.x % WARP_SIZE;
|
473 |
+
|
474 |
+
// Compute the sum per warp.
|
475 |
+
#pragma unroll
|
476 |
+
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
477 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
478 |
+
}
|
479 |
+
|
480 |
+
// Warp leaders store the data to shared memory.
|
481 |
+
if (lane == 0) {
|
482 |
+
red_smem[warp] = sum;
|
483 |
+
}
|
484 |
+
|
485 |
+
// Make sure the data is in shared memory.
|
486 |
+
__syncthreads();
|
487 |
+
|
488 |
+
// The warps compute the final sums.
|
489 |
+
if (lane < WARPS_PER_BLOCK) {
|
490 |
+
sum = red_smem[lane];
|
491 |
+
}
|
492 |
+
|
493 |
+
// Parallel reduction inside the warp.
|
494 |
+
#pragma unroll
|
495 |
+
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
|
496 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
497 |
+
}
|
498 |
+
|
499 |
+
// Broadcast to other threads.
|
500 |
+
return __shfl_sync(uint32_t(-1), sum, 0);
|
501 |
+
}
|
502 |
+
|
503 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
504 |
+
|
505 |
+
inline __device__ void convert_from_float(float& dst, float src)
|
506 |
+
{
|
507 |
+
dst = src;
|
508 |
+
}
|
509 |
+
|
510 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
511 |
+
|
512 |
+
inline __device__ void convert_from_float(uint16_t& dst, float src)
|
513 |
+
{
|
514 |
+
dst = float_to_half(src);
|
515 |
+
}
|
516 |
+
|
517 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
518 |
+
|
519 |
+
inline __device__ void convert_from_float(uint32_t& dst, float2 src)
|
520 |
+
{
|
521 |
+
dst = float2_to_half2(src);
|
522 |
+
}
|
523 |
+
|
524 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
525 |
+
#ifdef ENABLE_BF16
|
526 |
+
inline __device__ void convert_from_float(__nv_bfloat16& dst, float src)
|
527 |
+
{
|
528 |
+
dst = __float2bfloat16(src);
|
529 |
+
}
|
530 |
+
|
531 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
532 |
+
|
533 |
+
inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src)
|
534 |
+
{
|
535 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
536 |
+
dst = __float22bfloat162_rn(src);
|
537 |
+
#else
|
538 |
+
dst = __floats2bfloat162_rn(src.x, src.y);
|
539 |
+
#endif
|
540 |
+
}
|
541 |
+
#endif // ENABLE_BF16
|
542 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
543 |
+
|
544 |
+
inline __device__ void convert_from_float(uint2& dst, Float4_ src)
|
545 |
+
{
|
546 |
+
dst.x = float2_to_half2(src.x);
|
547 |
+
dst.y = float2_to_half2(src.y);
|
548 |
+
}
|
549 |
+
|
550 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
551 |
+
|
552 |
+
inline __device__ void convert_from_float(uint2& dst, float4 src)
|
553 |
+
{
|
554 |
+
convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
|
555 |
+
}
|
556 |
+
|
557 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
558 |
+
|
559 |
+
inline __device__ void convert_from_float(uint4& dst, Float8_ src)
|
560 |
+
{
|
561 |
+
dst.x = float2_to_half2(src.x);
|
562 |
+
dst.y = float2_to_half2(src.y);
|
563 |
+
dst.z = float2_to_half2(src.z);
|
564 |
+
dst.w = float2_to_half2(src.w);
|
565 |
+
}
|
566 |
+
|
567 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
568 |
+
|
569 |
+
#ifdef ENABLE_BF16
|
570 |
+
inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src)
|
571 |
+
{
|
572 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
573 |
+
dst.x = __float22bfloat162_rn(src.x);
|
574 |
+
dst.y = __float22bfloat162_rn(src.y);
|
575 |
+
#else
|
576 |
+
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
|
577 |
+
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
|
578 |
+
#endif
|
579 |
+
}
|
580 |
+
|
581 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
582 |
+
|
583 |
+
inline __device__ void convert_from_float(bf16_4_t& dst, float4 src)
|
584 |
+
{
|
585 |
+
convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
|
586 |
+
}
|
587 |
+
|
588 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
589 |
+
|
590 |
+
inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
|
591 |
+
{
|
592 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
593 |
+
dst.x = __float22bfloat162_rn(src.x);
|
594 |
+
dst.y = __float22bfloat162_rn(src.y);
|
595 |
+
dst.z = __float22bfloat162_rn(src.z);
|
596 |
+
dst.w = __float22bfloat162_rn(src.w);
|
597 |
+
#else
|
598 |
+
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
|
599 |
+
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
|
600 |
+
dst.z = __floats2bfloat162_rn(src.z.x, src.z.y);
|
601 |
+
dst.w = __floats2bfloat162_rn(src.w.x, src.w.y);
|
602 |
+
#endif
|
603 |
+
}
|
604 |
+
#endif // ENABLE_BF16
|
605 |
+
|
606 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
607 |
+
|
608 |
+
inline __device__ void convert_from_float(float2& dst, float2 src)
|
609 |
+
{
|
610 |
+
dst = src;
|
611 |
+
}
|
612 |
+
|
613 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
614 |
+
|
615 |
+
inline __device__ void convert_from_float(float4& dst, float4 src)
|
616 |
+
{
|
617 |
+
dst = src;
|
618 |
+
}
|
619 |
+
|
620 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
621 |
+
|
622 |
+
inline __device__ float convert_to_float(float4 u)
|
623 |
+
{
|
624 |
+
return u.x;
|
625 |
+
}
|
626 |
+
|
627 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
628 |
+
|
629 |
+
inline __device__ float convert_to_float(uint4 u)
|
630 |
+
{
|
631 |
+
float2 tmp = half2_to_float2(u.x);
|
632 |
+
return tmp.x;
|
633 |
+
}
|
634 |
+
|
635 |
+
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
|
636 |
+
|
637 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
638 |
+
|
639 |
+
inline __device__ float cast_to_float(float u)
|
640 |
+
{
|
641 |
+
return u;
|
642 |
+
}
|
643 |
+
|
644 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
645 |
+
|
646 |
+
inline __device__ float2 cast_to_float(float2 u)
|
647 |
+
{
|
648 |
+
return u;
|
649 |
+
}
|
650 |
+
|
651 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
652 |
+
|
653 |
+
inline __device__ float4 cast_to_float(float4 u)
|
654 |
+
{
|
655 |
+
return u;
|
656 |
+
}
|
657 |
+
|
658 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
659 |
+
|
660 |
+
inline __device__ Float4_ cast_to_float(Float4_ u)
|
661 |
+
{
|
662 |
+
return u;
|
663 |
+
}
|
664 |
+
|
665 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
666 |
+
|
667 |
+
inline __device__ Float8_ cast_to_float(Float8_ u)
|
668 |
+
{
|
669 |
+
return u;
|
670 |
+
}
|
671 |
+
|
672 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
673 |
+
|
674 |
+
inline __device__ float2 cast_to_float(uint32_t u)
|
675 |
+
{
|
676 |
+
return half2_to_float2(u);
|
677 |
+
}
|
678 |
+
|
679 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
680 |
+
|
681 |
+
inline __device__ Float4_ cast_to_float(uint2 u)
|
682 |
+
{
|
683 |
+
Float4_ tmp;
|
684 |
+
tmp.x = half2_to_float2(u.x);
|
685 |
+
tmp.y = half2_to_float2(u.y);
|
686 |
+
return tmp;
|
687 |
+
}
|
688 |
+
|
689 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
690 |
+
|
691 |
+
inline __device__ Float8_ cast_to_float(uint4 u)
|
692 |
+
{
|
693 |
+
Float8_ tmp;
|
694 |
+
tmp.x = half2_to_float2(u.x);
|
695 |
+
tmp.y = half2_to_float2(u.y);
|
696 |
+
tmp.z = half2_to_float2(u.z);
|
697 |
+
tmp.w = half2_to_float2(u.w);
|
698 |
+
return tmp;
|
699 |
+
}
|
700 |
+
|
701 |
+
#endif
|
702 |
+
|
703 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
704 |
+
|
705 |
+
inline __device__ float float_from_int8(int8_t u)
|
706 |
+
{
|
707 |
+
return u;
|
708 |
+
}
|
709 |
+
|
710 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
711 |
+
|
712 |
+
inline __device__ float2 float_from_int8(int16_t u)
|
713 |
+
{
|
714 |
+
union {
|
715 |
+
int16_t int16;
|
716 |
+
int8_t int8[2];
|
717 |
+
};
|
718 |
+
int16 = u;
|
719 |
+
return make_float2(int8[0], int8[1]);
|
720 |
+
}
|
721 |
+
|
722 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
723 |
+
|
724 |
+
inline __device__ float4 float_from_int8(int32_t u)
|
725 |
+
{
|
726 |
+
union {
|
727 |
+
int32_t int32;
|
728 |
+
int8_t int8[4];
|
729 |
+
};
|
730 |
+
int32 = u;
|
731 |
+
return make_float4(int8[0], int8[1], int8[2], int8[3]);
|
732 |
+
}
|
733 |
+
|
734 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
735 |
+
|
736 |
+
// clang-format off
|
737 |
+
inline __device__ Float8_ float_from_int8(int64_t u)
|
738 |
+
{
|
739 |
+
union {
|
740 |
+
int64_t int64;
|
741 |
+
int16_t int16[4];
|
742 |
+
};
|
743 |
+
int64 = u;
|
744 |
+
return Float8_ {float_from_int8(int16[0]),
|
745 |
+
float_from_int8(int16[1]),
|
746 |
+
float_from_int8(int16[2]),
|
747 |
+
float_from_int8(int16[3])};
|
748 |
+
}
|
749 |
+
// clang-format on
|
750 |
+
|
751 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
752 |
+
|
753 |
+
inline __device__ int8_t cast_to_int8(float val)
|
754 |
+
{
|
755 |
+
union {
|
756 |
+
int8_t int8[2];
|
757 |
+
int16_t int16;
|
758 |
+
};
|
759 |
+
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
|
760 |
+
return int8[0];
|
761 |
+
}
|
762 |
+
|
763 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
764 |
+
|
765 |
+
inline __device__ int32_t cast_to_int8(float4 val)
|
766 |
+
{
|
767 |
+
union {
|
768 |
+
int8_t int8[4];
|
769 |
+
int32_t int32;
|
770 |
+
};
|
771 |
+
int8[0] = cast_to_int8(val.x);
|
772 |
+
int8[1] = cast_to_int8(val.y);
|
773 |
+
int8[2] = cast_to_int8(val.z);
|
774 |
+
int8[3] = cast_to_int8(val.w);
|
775 |
+
return int32;
|
776 |
+
}
|
777 |
+
|
778 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
779 |
+
|
780 |
+
inline __device__ int64_t cast_to_int8(Float8_ val)
|
781 |
+
{
|
782 |
+
union {
|
783 |
+
int8_t int8[8];
|
784 |
+
int64_t int64;
|
785 |
+
};
|
786 |
+
int8[0] = cast_to_int8(val.x.x);
|
787 |
+
int8[1] = cast_to_int8(val.x.y);
|
788 |
+
int8[2] = cast_to_int8(val.y.x);
|
789 |
+
int8[3] = cast_to_int8(val.y.y);
|
790 |
+
int8[4] = cast_to_int8(val.z.x);
|
791 |
+
int8[5] = cast_to_int8(val.z.y);
|
792 |
+
int8[6] = cast_to_int8(val.w.x);
|
793 |
+
int8[7] = cast_to_int8(val.w.y);
|
794 |
+
return int64;
|
795 |
+
}
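// Illustrative only (not part of the original source): the helpers above form a symmetric
// int8 round trip -- float_from_int8() widens packed int8 lanes to float, cast_to_int8()
// saturates floats back to packed int8 (cvt.rni.sat.s8.f32). A minimal sketch of how they
// combine with a per-tensor scale, as the kernel does for params.int8_mode == 2; the helper
// names and the `scale` parameter are assumptions for the example.
inline __device__ float4 example_dequant_int8x4(int32_t packed, float scale)
{
    // scale * int8 -> float, mirroring the quantized Q/K/V loads in the kernel below.
    float4 f = float_from_int8(packed);
    return make_float4(f.x * scale, f.y * scale, f.z * scale, f.w * scale);
}

inline __device__ int32_t example_requant_int8x4(float4 v, float inv_scale)
{
    // Multiply by the inverse scale and saturate back to 4 packed int8 values.
    return cast_to_int8(make_float4(v.x * inv_scale, v.y * inv_scale, v.z * inv_scale, v.w * inv_scale));
}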
|
796 |
+
|
797 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
798 |
+
|
799 |
+
template<typename T>
|
800 |
+
inline __device__ __host__ T div_up(T m, T n)
|
801 |
+
{
|
802 |
+
return (m + n - 1) / n;
|
803 |
+
}
|
804 |
+
|
805 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
806 |
+
|
807 |
+
template<typename T, bool DO_CROSS_ATTENTION>
|
808 |
+
inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
|
809 |
+
int threads_per_value,
|
810 |
+
int threads_per_block)
|
811 |
+
{
|
812 |
+
// The amount of shared memory needed to store the Q*K^T values in float.
|
813 |
+
const int max_timesteps = min(params.timestep, params.memory_max_len);
|
814 |
+
size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
|
815 |
+
|
816 |
+
// The extra memory needed if we are not using floats for the final logits.
|
817 |
+
size_t logits_sz = 0;
|
818 |
+
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
|
819 |
+
if (sizeof(T) != 4) {
|
820 |
+
// TODO
|
821 |
+
logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) :
|
822 |
+
div_up(max_timesteps + 1, 4) * 4 * sizeof(T);
|
823 |
+
}
|
824 |
+
#endif
|
825 |
+
|
826 |
+
// The total size needed during softmax.
|
827 |
+
size_t softmax_sz = qk_sz + logits_sz;
|
828 |
+
|
829 |
+
// The number of partial rows to reduce in the final reduction.
|
830 |
+
int rows_per_red = threads_per_block / threads_per_value;
|
831 |
+
// The amount of storage needed to finalize the outputs.
|
832 |
+
size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;
|
833 |
+
|
834 |
+
size_t transpose_rotary_size = 0;
|
835 |
+
if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
|
836 |
+
transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T);
|
837 |
+
}
|
838 |
+
|
839 |
+
// The max.
|
840 |
+
return max(max(softmax_sz, red_sz), transpose_rotary_size);
|
841 |
+
}
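// Illustrative only (not part of the original source): a minimal sketch of how a launcher
// could use smem_size_in_bytes() to size the dynamic shared memory of the kernel defined
// further below. The kernel is passed in as a pointer so the sketch does not depend on a
// particular template instantiation; the thread counts are the caller's choice.
template<typename T, bool DO_CROSS_ATTENTION, typename KernelT>
inline void example_launch_with_dynamic_smem(KernelT* kernel,
                                             const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
                                             dim3 grid, int threads_per_value, int threads_per_block,
                                             cudaStream_t stream)
{
    const size_t smem_sz = smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, threads_per_value, threads_per_block);
    if (smem_sz >= 48 * 1024) {
        // Kernels that need more than 48KB of dynamic shared memory must opt in explicitly.
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(smem_sz));
    }
    kernel<<<grid, threads_per_block, smem_sz, stream>>>(params);
}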
|
842 |
+
|
843 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
844 |
+
|
845 |
+
inline __device__ constexpr uint32_t shfl_mask(int threads)
|
846 |
+
{
|
847 |
+
return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
|
848 |
+
}
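// Illustrative only (not part of the original source): shfl_mask() builds the participation
// mask for __shfl_xor_sync when fewer than 32 lanes take part in a butterfly reduction.
// A minimal sketch of the pattern used later for the Q*K^T partial sums, assuming THREADS
// is a power of two and the participating lanes are lanes 0..THREADS-1:
template<int THREADS>
inline __device__ float example_partial_warp_sum(float val)
{
#pragma unroll
    for (int mask = THREADS / 2; mask >= 1; mask /= 2) {
        val += __shfl_xor_sync(shfl_mask(THREADS), val, mask);
    }
    return val;  // every participating lane ends up holding the full sum
}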
|
849 |
+
|
850 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
851 |
+
|
852 |
+
template<
|
853 |
+
// The type of the inputs. Supported types: float and half.
|
854 |
+
typename T,
|
855 |
+
// The hidden dimension per head.
|
856 |
+
int Dh,
|
857 |
+
int Dh_MAX,
|
858 |
+
// The number of threads per key.
|
859 |
+
int THREADS_PER_KEY,
|
860 |
+
// The number of threads per value.
|
861 |
+
int THREADS_PER_VALUE,
|
862 |
+
// The number of threads in a threadblock.
|
863 |
+
int THREADS_PER_BLOCK,
|
864 |
+
bool DO_CROSS_ATTENTION>
|
865 |
+
__global__ void masked_multihead_attention_kernel(Multihead_attention_params<T, DO_CROSS_ATTENTION> params)
|
866 |
+
{
|
867 |
+
|
868 |
+
// Make sure the hidden dimension per head is a multiple of the number of threads per key.
|
869 |
+
static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
|
870 |
+
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
|
871 |
+
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
|
872 |
+
|
873 |
+
// The size of a warp.
|
874 |
+
constexpr int WARP_SIZE = 32;
|
875 |
+
// The number of warps in a threadblock.
|
876 |
+
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
|
877 |
+
|
878 |
+
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
|
879 |
+
extern __shared__ char smem_[];
|
880 |
+
|
881 |
+
// The shared memory for the Q*K^T values and partial logits in softmax.
|
882 |
+
float* qk_smem = reinterpret_cast<float*>(smem_);
|
883 |
+
|
884 |
+
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
|
885 |
+
char* logits_smem_ = smem_;
|
886 |
+
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
|
887 |
+
if (sizeof(T) != 4) {
|
888 |
+
// TODO - change to tlength
|
889 |
+
const int max_timesteps = min(params.timestep, params.memory_max_len);
|
890 |
+
logits_smem_ +=
|
891 |
+
(DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
|
892 |
+
}
|
893 |
+
T* logits_smem = reinterpret_cast<T*>(logits_smem_);
|
894 |
+
#else
|
895 |
+
float* logits_smem = reinterpret_cast<float*>(logits_smem_);
|
896 |
+
#endif
|
897 |
+
|
898 |
+
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
|
899 |
+
T* out_smem = reinterpret_cast<T*>(smem_);
|
900 |
+
|
901 |
+
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
|
902 |
+
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
|
903 |
+
|
904 |
+
// A vector of Q or K elements for the current timestep.
|
905 |
+
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
|
906 |
+
|
907 |
+
// Use alignment for safely casting the shared buffers as Qk_vec.
|
908 |
+
// Shared memory to store Q inputs.
|
909 |
+
__shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
|
910 |
+
|
911 |
+
// This is one of the reasons we should have a separate kernel for cross attention
|
912 |
+
__shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];
|
913 |
+
|
914 |
+
// A vector of Q or K elements for the current timestep.
|
915 |
+
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
|
916 |
+
// The number of elements per vector.
|
917 |
+
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
|
918 |
+
// Make sure the hidden size per head is a multiple of the vector size.
|
919 |
+
static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
|
920 |
+
// We will use block wide reduction if needed
|
921 |
+
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
|
922 |
+
// The number of vectors per warp.
|
923 |
+
constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
|
924 |
+
|
925 |
+
// The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
|
926 |
+
// owns x elements, we have to decompose the linear index into chunks of x values and the posi-
|
927 |
+
// tion of the thread in that chunk.
|
928 |
+
|
929 |
+
// The number of elements in a chunk of 16B (that's the x in the above formula).
|
930 |
+
constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
|
931 |
+
// The number of K vectors in 16B.
|
932 |
+
constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
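// Worked example of the [B, H, Dh/x, L, x] layout (illustrative numbers, not from the
// original source): with T = half, x = QK_ELTS_IN_16B = 8. The element written at
// timestep t by 16B chunk co at intra-chunk position ci lives at the linear offset
//   bhi_kv * memory_max_len * Dh        // jump to this (batch, kv head)
//   + co * memory_max_len * 8           // jump to the 16B chunk along Dh
//   + t * 8 + ci,                       // timestep-major inside the chunk
// which is exactly the `offset` computed for the K-cache loads/stores below.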
|
933 |
+
|
934 |
+
// The batch/beam idx
|
935 |
+
const int bi = blockIdx.y;
|
936 |
+
if (params.finished != nullptr && params.finished[bi] == true) {
|
937 |
+
return;
|
938 |
+
}
|
939 |
+
// The beam idx
|
940 |
+
const int beami = bi % params.beam_width;
|
941 |
+
// The "beam-aware" batch idx
|
942 |
+
const int bbi = bi / params.beam_width;
|
943 |
+
// The head.
|
944 |
+
// const int hi = blockIdx.x;
|
945 |
+
const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
|
946 |
+
const int hi_kv = hi / params.num_heads_q_kv_ratio;
|
947 |
+
// Combine the batch and the head indices.
|
948 |
+
const int bhi = bi * params.num_heads + hi;
|
949 |
+
const int bhi_kv = bi * params.num_heads_kv + hi_kv;
|
950 |
+
// Combine the "beam-aware" batch idx and the head indices.
|
951 |
+
const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv;
|
952 |
+
// The thread in the block.
|
953 |
+
const int tidx = threadIdx.x;
|
954 |
+
|
955 |
+
const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);
|
956 |
+
|
957 |
+
// While doing the product Q*K^T for the different keys we track the max.
|
958 |
+
float qk_max = -FLT_MAX;
|
959 |
+
|
960 |
+
float qk = 0.0F;
|
961 |
+
|
962 |
+
int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh;
|
963 |
+
int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh;
|
964 |
+
int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh;
|
965 |
+
|
966 |
+
const size_t bi_seq_len_offset = bi * params.memory_max_len;
|
967 |
+
|
968 |
+
// int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
|
969 |
+
int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
|
970 |
+
(params.length_per_sample == nullptr) ?
|
971 |
+
params.timestep :
|
972 |
+
params.length_per_sample[bi] + params.max_prefix_prompt_length;
|
973 |
+
const int first_step = max(0, tlength + 1 - params.memory_max_len);
|
974 |
+
const int tlength_circ = tlength % params.memory_max_len;
|
975 |
+
|
976 |
+
// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
|
977 |
+
const bool is_masked = tidx >= QK_VECS_PER_WARP;
|
978 |
+
|
979 |
+
// The offset in the Q and K buffer also accounts for the batch.
|
980 |
+
int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
|
981 |
+
int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
|
982 |
+
// The offset in the bias buffer.
|
983 |
+
int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
|
984 |
+
int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
|
985 |
+
|
986 |
+
const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
|
987 |
+
const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
|
988 |
+
|
989 |
+
// Trigger the loads from the Q and K buffers.
|
990 |
+
Qk_vec q;
|
991 |
+
zero(q);
|
992 |
+
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
|
993 |
+
if (params.int8_mode == 2) {
|
994 |
+
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
|
995 |
+
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
|
996 |
+
const auto q_scaling = params.qkv_scale_out[0];
|
997 |
+
const auto q_quant =
|
998 |
+
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
|
999 |
+
|
1000 |
+
convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
|
1001 |
+
}
|
1002 |
+
else {
|
1003 |
+
q = *reinterpret_cast<const Qk_vec*>(¶ms.q[q_offset]);
|
1004 |
+
}
|
1005 |
+
}
|
1006 |
+
|
1007 |
+
Qk_vec k;
|
1008 |
+
zero(k);
|
1009 |
+
if (DO_CROSS_ATTENTION) {
|
1010 |
+
// The 16B chunk written by the thread.
|
1011 |
+
int co = tidx / QK_VECS_IN_16B;
|
1012 |
+
// The position of the thread in that 16B chunk.
|
1013 |
+
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
|
1014 |
+
|
1015 |
+
// Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
|
1016 |
+
int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
|
1017 |
+
// params.timestep*QK_ELTS_IN_16B +
|
1018 |
+
tlength * QK_ELTS_IN_16B + ci;
|
1019 |
+
k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
|
1020 |
+
*reinterpret_cast<const Qk_vec*>(¶ms.k_cache[offset]) :
|
1021 |
+
k;
|
1022 |
+
}
|
1023 |
+
else {
|
1024 |
+
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
|
1025 |
+
if (params.int8_mode == 2) {
|
1026 |
+
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
|
1027 |
+
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
|
1028 |
+
const auto k_scaling = params.qkv_scale_out[1];
|
1029 |
+
const auto k_quant =
|
1030 |
+
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
|
1031 |
+
|
1032 |
+
convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
|
1033 |
+
}
|
1034 |
+
else {
|
1035 |
+
k = *reinterpret_cast<const Qk_vec*>(¶ms.k[k_offset]);
|
1036 |
+
}
|
1037 |
+
}
|
1038 |
+
}
|
1039 |
+
|
1040 |
+
// Trigger the loads from the Q and K bias buffers.
|
1041 |
+
Qk_vec q_bias;
|
1042 |
+
zero(q_bias);
|
1043 |
+
q_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
|
1044 |
+
*reinterpret_cast<const Qk_vec*>(¶ms.q_bias[q_bias_offset]) :
|
1045 |
+
q_bias;
|
1046 |
+
|
1047 |
+
Qk_vec k_bias;
|
1048 |
+
zero(k_bias);
|
1049 |
+
if (handle_kv) {
|
1050 |
+
k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
|
1051 |
+
*reinterpret_cast<const Qk_vec*>(¶ms.k_bias[k_bias_offset]) :
|
1052 |
+
k_bias;
|
1053 |
+
}
|
1054 |
+
|
1055 |
+
// Computes the Q/K values with bias.
|
1056 |
+
q = add(q, q_bias);
|
1057 |
+
if (handle_kv) {
|
1058 |
+
k = add(k, k_bias);
|
1059 |
+
}
|
1060 |
+
if (do_ia3 && !is_masked) {
|
1061 |
+
k = mul<Qk_vec, Qk_vec, Qk_vec>(
|
1062 |
+
k,
|
1063 |
+
*reinterpret_cast<const Qk_vec*>(
|
1064 |
+
¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]));
|
1065 |
+
}
|
1066 |
+
|
1067 |
+
// Padded len
|
1068 |
+
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
|
1069 |
+
if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
|
1070 |
+
if (handle_kv) {
|
1071 |
+
if (params.rotary_cos == nullptr) {
|
1072 |
+
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
1073 |
+
} else {
|
1074 |
+
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len,
|
1075 |
+
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
1076 |
+
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
1077 |
+
}
|
1078 |
+
}
|
1079 |
+
else {
|
1080 |
+
if (params.rotary_cos == nullptr) {
|
1081 |
+
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
1082 |
+
} else {
|
1083 |
+
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len,
|
1084 |
+
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
1085 |
+
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
1086 |
+
}
|
1087 |
+
}
|
1088 |
+
}
|
1089 |
+
else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
|
1090 |
+
const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
|
1091 |
+
|
1092 |
+
T* q_smem = reinterpret_cast<T*>(smem_);
|
1093 |
+
T* k_smem = q_smem + params.rotary_embedding_dim;
|
1094 |
+
|
1095 |
+
const int half_rotary_dim = params.rotary_embedding_dim / 2;
|
1096 |
+
const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim;
|
1097 |
+
const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim;
|
1098 |
+
const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts
|
1099 |
+
|
1100 |
+
assert(half_rotary_dim % QK_VEC_SIZE == 0);
|
1101 |
+
|
1102 |
+
if (do_rotary) {
|
1103 |
+
*reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
|
1104 |
+
|
1105 |
+
if (handle_kv) {
|
1106 |
+
*reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
|
1107 |
+
}
|
1108 |
+
}
|
1109 |
+
|
1110 |
+
__syncthreads();
|
1111 |
+
|
1112 |
+
const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
|
1113 |
+
constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
|
1114 |
+
if (do_rotary) {
|
1115 |
+
mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
1116 |
+
|
1117 |
+
if (handle_kv) {
|
1118 |
+
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
1119 |
+
|
1120 |
+
if (params.rotary_cos == nullptr) {
|
1121 |
+
mmha::apply_rotary_embedding(
|
1122 |
+
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
1123 |
+
} else {
|
1124 |
+
mmha::apply_rotary_embedding(
|
1125 |
+
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len,
|
1126 |
+
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
1127 |
+
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
1128 |
+
}
|
1129 |
+
|
1130 |
+
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
1131 |
+
}
|
1132 |
+
else {
|
1133 |
+
if (params.rotary_cos == nullptr) {
|
1134 |
+
mmha::apply_rotary_embedding(
|
1135 |
+
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
|
1136 |
+
} else {
|
1137 |
+
mmha::apply_rotary_embedding(
|
1138 |
+
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength,
|
1139 |
+
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
1140 |
+
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
1141 |
+
}
|
1142 |
+
}
|
1143 |
+
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
1144 |
+
}
|
1145 |
+
|
1146 |
+
__syncthreads();
|
1147 |
+
|
1148 |
+
if (do_rotary) {
|
1149 |
+
q = *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx);
|
1150 |
+
if (handle_kv) {
|
1151 |
+
k = *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx);
|
1152 |
+
}
|
1153 |
+
}
|
1154 |
+
|
1155 |
+
__syncthreads();
|
1156 |
+
}
|
1157 |
+
|
1158 |
+
if (!is_masked) {
|
1159 |
+
// Store the Q values to shared memory.
|
1160 |
+
*reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
|
1161 |
+
|
1162 |
+
// Store Dh values of k_bias into smem, since we will need to add them later
|
1163 |
+
// if params.timestep == 0
|
1164 |
+
if (DO_CROSS_ATTENTION && params.timestep == 0) {
|
1165 |
+
*reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
|
1166 |
+
}
|
1167 |
+
|
1168 |
+
// Write the K values to the global memory cache.
|
1169 |
+
//
|
1170 |
+
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
|
1171 |
+
// system. We designed it this way as it allows much better memory loads (and there are many
|
1172 |
+
// more loads) + the stores are really "write and forget" since we won't need the ack before
|
1173 |
+
// the end of the kernel. There's plenty of time for the transactions to complete.
|
1174 |
+
|
1175 |
+
// The 16B chunk written by the thread.
|
1176 |
+
int co = tidx / QK_VECS_IN_16B;
|
1177 |
+
// The position of the thread in that 16B chunk.
|
1178 |
+
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
|
1179 |
+
|
1180 |
+
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
|
1181 |
+
int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
|
1182 |
+
// params.timestep*QK_ELTS_IN_16B +
|
1183 |
+
tlength_circ * QK_ELTS_IN_16B + ci;
|
1184 |
+
|
1185 |
+
if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) {
|
1186 |
+
// Trigger the stores to global memory.
|
1187 |
+
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
|
1188 |
+
*reinterpret_cast<Qk_vec*>(¶ms.k_cache[offset]) = k;
|
1189 |
+
}
|
1190 |
+
}
|
1191 |
+
|
1192 |
+
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
|
1193 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
1194 |
+
using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
|
1195 |
+
#else
|
1196 |
+
using Qk_vec_acum = Qk_vec;
|
1197 |
+
#endif
|
1198 |
+
qk = dot<Qk_vec_acum, Qk_vec>(q, k);
|
1199 |
+
if (QK_VECS_PER_WARP <= WARP_SIZE) {
|
1200 |
+
#pragma unroll
|
1201 |
+
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
|
1202 |
+
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
|
1203 |
+
}
|
1204 |
+
}
|
1205 |
+
}
|
1206 |
+
|
1207 |
+
if (QK_VECS_PER_WARP > WARP_SIZE) {
|
1208 |
+
constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
|
1209 |
+
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
|
1210 |
+
}
|
1211 |
+
|
1212 |
+
// Store that value in shared memory. Keep the Q*K^T value in register for softmax.
|
1213 |
+
if (tidx == 0) {
|
1214 |
+
// Normalize qk.
|
1215 |
+
qk *= params.inv_sqrt_dh;
|
1216 |
+
if (params.relative_attention_bias != nullptr) {
|
1217 |
+
qk = add(qk,
|
1218 |
+
params.relative_attention_bias[hi * params.relative_attention_bias_stride
|
1219 |
+
* params.relative_attention_bias_stride
|
1220 |
+
+ (tlength - padd_len) * params.relative_attention_bias_stride
|
1221 |
+
+ (tlength - padd_len)]);
|
1222 |
+
}
|
1223 |
+
// We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
|
1224 |
+
|
1225 |
+
qk_max = qk;
|
1226 |
+
qk_smem[tlength - first_step] = qk;
|
1227 |
+
// qk_smem[params.timestep] = qk;
|
1228 |
+
}
|
1229 |
+
|
1230 |
+
// Make sure the data is in shared memory.
|
1231 |
+
__syncthreads();
|
1232 |
+
|
1233 |
+
// The type of queries and keys for the math in the Q*K^T product.
|
1234 |
+
using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
|
1235 |
+
// The number of elements per vector.
|
1236 |
+
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
|
1237 |
+
// Make sure the hidden size per head is a multiple of the vector size.
|
1238 |
+
static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
|
1239 |
+
// The number of elements per thread.
|
1240 |
+
constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
|
1241 |
+
// The number of vectors per thread.
|
1242 |
+
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
|
1243 |
+
|
1244 |
+
// The position of the first key loaded by each thread from the cache buffer (for this B * H).
|
1245 |
+
int ko = tidx / THREADS_PER_KEY;
|
1246 |
+
// The position of the thread in the chunk of keys.
|
1247 |
+
int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;
|
1248 |
+
|
1249 |
+
static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);
|
1250 |
+
|
1251 |
+
// Load the Q values from shared memory. The values are reused during the loop on K.
|
1252 |
+
K_vec q_vec[K_VECS_PER_THREAD];
|
1253 |
+
#pragma unroll
|
1254 |
+
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
|
1255 |
+
q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
|
1256 |
+
}
|
1257 |
+
|
1258 |
+
K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
|
1259 |
+
if (DO_CROSS_ATTENTION && params.timestep == 0) {
|
1260 |
+
#pragma unroll
|
1261 |
+
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
|
1262 |
+
k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
|
1263 |
+
}
|
1264 |
+
}
|
1265 |
+
|
1266 |
+
// The number of timesteps loaded per iteration.
|
1267 |
+
constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
|
1268 |
+
// The number of keys per warp.
|
1269 |
+
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
|
1270 |
+
|
1271 |
+
// The base pointer for the key in the cache buffer.
|
1272 |
+
T* k_cache = ¶ms.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
|
1273 |
+
// Base pointer for the beam's batch, before offsetting with indirection buffer
|
1274 |
+
T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki];
|
1275 |
+
|
1276 |
+
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
|
1277 |
+
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
|
1278 |
+
int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
|
1279 |
+
|
1280 |
+
// The prefix prompt length, if present.
|
1281 |
+
const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
|
1282 |
+
|
1283 |
+
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
|
1284 |
+
const bool has_beams = params.cache_indir != nullptr;
|
1285 |
+
const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr;
|
1286 |
+
|
1287 |
+
for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
|
1288 |
+
const int ti_circ = ti % params.memory_max_len;
|
1289 |
+
|
1290 |
+
// The keys loaded from the key cache.
|
1291 |
+
K_vec k[K_VECS_PER_THREAD];
|
1292 |
+
K_vec k_vec_zero;
|
1293 |
+
zero(k_vec_zero);
|
1294 |
+
#pragma unroll
|
1295 |
+
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
|
1296 |
+
int jj = ii * params.memory_max_len + ti_circ;
|
1297 |
+
// if( ti < params.timestep ) {
|
1298 |
+
const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
|
1299 |
+
if (ti < tlength) {
|
1300 |
+
if (!within_bounds) {
|
1301 |
+
k[ii] = k_vec_zero;
|
1302 |
+
}
|
1303 |
+
else {
|
1304 |
+
if (has_beams) {
|
1305 |
+
const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
|
1306 |
+
k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
|
1307 |
+
}
|
1308 |
+
else {
|
1309 |
+
k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
|
1310 |
+
}
|
1311 |
+
}
|
1312 |
+
// add bias and update k_cache
|
1313 |
+
if (DO_CROSS_ATTENTION && params.timestep == 0) {
|
1314 |
+
k[ii] = add(k[ii], k_bias_vec[ii]);
|
1315 |
+
|
1316 |
+
if (do_ia3) {
|
1317 |
+
k[ii] = mul<K_vec, K_vec, K_vec>(
|
1318 |
+
k[ii],
|
1319 |
+
*reinterpret_cast<const K_vec*>(
|
1320 |
+
¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
|
1321 |
+
+ ii * THREADS_PER_KEY * K_VEC_SIZE]));
|
1322 |
+
}
|
1323 |
+
|
1324 |
+
if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
|
1325 |
+
*reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
|
1326 |
+
}
|
1327 |
+
}
|
1328 |
+
}
|
1329 |
+
}
|
1330 |
+
|
1331 |
+
// Perform the dot product and normalize qk.
|
1332 |
+
//
|
1333 |
+
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
|
1334 |
+
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;
|
1335 |
+
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
|
1336 |
+
|
1337 |
+
// Store the product to shared memory. There's one qk value per timestep. Update the max.
|
1338 |
+
// if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
|
1339 |
+
if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
|
1340 |
+
if (params.relative_attention_bias != nullptr) {
|
1341 |
+
qk = add(qk,
|
1342 |
+
params.relative_attention_bias[hi * params.relative_attention_bias_stride
|
1343 |
+
* params.relative_attention_bias_stride
|
1344 |
+
+ tlength * params.relative_attention_bias_stride + ti]);
|
1345 |
+
}
|
1346 |
+
if (params.linear_bias_slopes != nullptr) {
|
1347 |
+
// Apply the linear position bias: (ki - qi) * slope[hi].
|
1348 |
+
// The padding tokens are located between the input context and the generated tokens.
|
1349 |
+
// We need to remove the number of padding tokens in the distance computation.
|
1350 |
+
// ti : 0 1 2 3 4 5 6 7 8 9(tlength)
|
1351 |
+
// token: i i i i p p p o o o where i=input, p=pad, o=output.
|
1352 |
+
// e.g. ti = 2, dist = (9 - 3) - 2 = 4.
|
1353 |
+
int max_context_length = params.max_prefix_prompt_length + params.max_input_length;
|
1354 |
+
float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength;
|
1355 |
+
|
1356 |
+
qk += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
|
1357 |
+
}
|
1358 |
+
qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
|
1359 |
+
qk_smem[ti - first_step] = qk;
|
1360 |
+
}
|
1361 |
+
}
|
1362 |
+
|
1363 |
+
// Perform the final reduction to compute the max inside each warp.
|
1364 |
+
//
|
1365 |
+
// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
|
1366 |
+
// group so it's not needed to run the reduction inside the group (again).
|
1367 |
+
#pragma unroll
|
1368 |
+
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
|
1369 |
+
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
1370 |
+
}
|
1371 |
+
|
1372 |
+
// Decompose the thread index into warp and lane.
|
1373 |
+
const int warp = tidx / WARP_SIZE;
|
1374 |
+
const int lane = tidx % WARP_SIZE;
|
1375 |
+
|
1376 |
+
// The warp leader writes the max to shared memory.
|
1377 |
+
if (lane == 0) {
|
1378 |
+
red_smem[warp] = qk_max;
|
1379 |
+
}
|
1380 |
+
|
1381 |
+
// Make sure the products are in shared memory.
|
1382 |
+
__syncthreads();
|
1383 |
+
|
1384 |
+
// The warps finalize the reduction.
|
1385 |
+
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
|
1386 |
+
#pragma unroll
|
1387 |
+
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
|
1388 |
+
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
1389 |
+
}
|
1390 |
+
|
1391 |
+
// Broadcast to all the threads in the warp.
|
1392 |
+
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
1393 |
+
|
1394 |
+
// Compute the logits and start the sum.
|
1395 |
+
float sum = 0.f;
|
1396 |
+
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
|
1397 |
+
for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
|
1398 |
+
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
|
1399 |
+
float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
|
1400 |
+
sum += logit;
|
1401 |
+
qk_smem[ti - first_step] = logit;
|
1402 |
+
}
|
1403 |
+
|
1404 |
+
// Compute the sum.
|
1405 |
+
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
|
1406 |
+
|
1407 |
+
// Normalize the logits.
|
1408 |
+
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
|
1409 |
+
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
|
1410 |
+
const size_t cross_attention_out_offset =
|
1411 |
+
params.is_return_cross_attentions ?
|
1412 |
+
bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len :
|
1413 |
+
0;
|
1414 |
+
for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
|
1415 |
+
float logit = qk_smem[ti - first_step] * inv_sum;
|
1416 |
+
if (params.is_return_cross_attentions) {
|
1417 |
+
params.cross_attention_out[cross_attention_out_offset + ti] = logit;
|
1418 |
+
}
|
1419 |
+
convert_from_float(logits_smem[ti - first_step], logit);
|
1420 |
+
}
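// At this point qk_smem holds a numerically stable softmax over timesteps
// [first_step, tlength]: each entry was computed as exp(qk - qk_max) and then scaled by
// 1 / (sum + 1e-6). Reference sketch of the same computation in serial form (illustrative
// only, not part of the original source):
//
//   float m = -FLT_MAX;
//   for (int t = first_step; t <= tlength; ++t) m = fmaxf(m, qk[t]);
//   float s = 0.f;
//   for (int t = first_step; t <= tlength; ++t) { p[t] = expf(qk[t] - m); s += p[t]; }
//   for (int t = first_step; t <= tlength; ++t) p[t] /= (s + 1e-6f);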
|
1421 |
+
|
1422 |
+
// Put Values part below so we leverage __syncthreads
|
1423 |
+
// from the previous step
|
1424 |
+
|
1425 |
+
// The number of elements per vector.
|
1426 |
+
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
|
1427 |
+
// A vector of V elements for the current timestep.
|
1428 |
+
using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
|
1429 |
+
|
1430 |
+
// The value computed by this thread.
|
1431 |
+
int vo = tidx / THREADS_PER_VALUE;
|
1432 |
+
// The hidden dimensions computed by this particular thread.
|
1433 |
+
int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
|
1434 |
+
|
1435 |
+
// The base pointer for the value in the cache buffer.
|
1436 |
+
T* v_cache = ¶ms.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
|
1437 |
+
// Base pointer for the beam's batch, before offsetting with indirection buffer
|
1438 |
+
T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi];
|
1439 |
+
|
1440 |
+
// The number of values processed per iteration of the loop.
|
1441 |
+
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
|
1442 |
+
|
1443 |
+
// One group of threads computes the product(s) for the current timestep.
|
1444 |
+
V_vec v_bias;
|
1445 |
+
zero(v_bias);
|
1446 |
+
// if( vo == params.timestep % V_PER_ITER ) {
|
1447 |
+
if (Dh == Dh_MAX || vi < Dh) {
|
1448 |
+
if (handle_kv) {
|
1449 |
+
if (vo == tlength % V_PER_ITER) {
|
1450 |
+
// Trigger the loads from the V bias buffer.
|
1451 |
+
if (params.v_bias != nullptr) {
|
1452 |
+
v_bias = *reinterpret_cast<const V_vec*>(¶ms.v_bias[hi_kv * Dh + vi]);
|
1453 |
+
}
|
1454 |
+
if (DO_CROSS_ATTENTION) {
|
1455 |
+
*reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
|
1456 |
+
}
|
1457 |
+
}
|
1458 |
+
}
|
1459 |
+
}
|
1460 |
+
|
1461 |
+
// Carried over from the previous step, before the values pass.
|
1462 |
+
// Also make sure the logits are in shared memory.
|
1463 |
+
__syncthreads();
|
1464 |
+
|
1465 |
+
// Values continued
|
1466 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
|
1467 |
+
using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
|
1468 |
+
#else
|
1469 |
+
using V_vec_acum = V_vec;
|
1470 |
+
#endif
|
1471 |
+
// The partial outputs computed by each thread.
|
1472 |
+
V_vec_acum out;
|
1473 |
+
zero(out);
|
1474 |
+
|
1475 |
+
// Loop over the timesteps to compute the partial outputs.
|
1476 |
+
// for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
|
1477 |
+
if (Dh == Dh_MAX || vi < Dh) {
|
1478 |
+
for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
|
1479 |
+
const int ti_circ = ti % params.memory_max_len;
|
1480 |
+
|
1481 |
+
// Fetch offset based on cache_indir when beam sampling
|
1482 |
+
const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
|
1483 |
+
const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
|
1484 |
+
// Load the values from the cache.
|
1485 |
+
V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
|
1486 |
+
if (DO_CROSS_ATTENTION && params.timestep == 0) {
|
1487 |
+
v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
|
1488 |
+
if (do_ia3) {
|
1489 |
+
v = mul<V_vec, V_vec, V_vec>(
|
1490 |
+
v,
|
1491 |
+
*reinterpret_cast<const V_vec*>(
|
1492 |
+
¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
|
1493 |
+
}
|
1494 |
+
*reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
|
1495 |
+
}
|
1496 |
+
// Load the logits from shared memory.
|
1497 |
+
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
|
1498 |
+
float logit = logits_smem[ti - first_step];
|
1499 |
+
out = fma(logit, cast_to_float(v), out);
|
1500 |
+
#else
|
1501 |
+
T logit = logits_smem[ti - first_step];
|
1502 |
+
|
1503 |
+
// Update the partial sums.
|
1504 |
+
out = fma(logit, v, out);
|
1505 |
+
#endif
|
1506 |
+
}
|
1507 |
+
}
|
1508 |
+
|
1509 |
+
// One group of threads computes the product(s) for the current timestep.
|
1510 |
+
// if( vo == params.timestep % V_PER_ITER ) {
|
1511 |
+
if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
|
1512 |
+
|
1513 |
+
V_vec v;
|
1514 |
+
if (DO_CROSS_ATTENTION) {
|
1515 |
+
v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
|
1516 |
+
}
|
1517 |
+
else {
|
1518 |
+
// Trigger the loads from the V buffer.
|
1519 |
+
const auto v_offset = v_base_offset + vi;
|
1520 |
+
if (params.int8_mode == 2) {
|
1521 |
+
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
|
1522 |
+
using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
|
1523 |
+
const auto v_scaling = params.qkv_scale_out[2];
|
1524 |
+
const auto v_quant =
|
1525 |
+
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);
|
1526 |
+
|
1527 |
+
convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
|
1528 |
+
}
|
1529 |
+
else {
|
1530 |
+
v = *reinterpret_cast<const V_vec*>(¶ms.v[v_offset]);
|
1531 |
+
}
|
1532 |
+
// Trigger the loads from the V bias buffer.
|
1533 |
+
// V_vec v_bias = *reinterpret_cast<const V_vec*>(¶ms.v_bias[hi*Dh + vi]);
|
1534 |
+
}
|
1535 |
+
|
1536 |
+
// Compute the V values with bias.
|
1537 |
+
if (handle_kv) {
|
1538 |
+
v = add(v, v_bias);
|
1539 |
+
|
1540 |
+
if (do_ia3) {
|
1541 |
+
v = mul<V_vec, V_vec, V_vec>(
|
1542 |
+
v,
|
1543 |
+
*reinterpret_cast<const V_vec*>(
|
1544 |
+
¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
|
1545 |
+
}
|
1546 |
+
|
1547 |
+
// Store the values with bias back to global memory in the cache for V.
|
1548 |
+
if (hi % params.num_heads_q_kv_ratio == 0) {
|
1549 |
+
//*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
|
1550 |
+
*reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
|
1551 |
+
}
|
1552 |
+
}
|
1553 |
+
|
1554 |
+
// Initialize the output value with the current timestep.
|
1555 |
+
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
|
1556 |
+
// out = fma(logits_smem[params.timestep], cast_to_float(v), out);
|
1557 |
+
out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
|
1558 |
+
#else
|
1559 |
+
// out = fma(logits_smem[params.timestep], v, out);
|
1560 |
+
out = fma(logits_smem[tlength - first_step], v, out);
|
1561 |
+
#endif
|
1562 |
+
}
|
1563 |
+
|
1564 |
+
// Make sure we can start writing to shared memory.
|
1565 |
+
__syncthreads();
|
1566 |
+
|
1567 |
+
// Run the final reduction amongst the different groups computing different partial outputs.
|
1568 |
+
if (Dh == Dh_MAX || vi < Dh) {
|
1569 |
+
#pragma unroll
|
1570 |
+
for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
|
1571 |
+
|
1572 |
+
// The midpoint in the number of active groups.
|
1573 |
+
int midpoint = active_groups / 2;
|
1574 |
+
|
1575 |
+
// The upper part of active threads store to shared memory.
|
1576 |
+
if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
|
1577 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
|
1578 |
+
convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
|
1579 |
+
#else
|
1580 |
+
*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
|
1581 |
+
#endif
|
1582 |
+
}
|
1583 |
+
__syncthreads();
|
1584 |
+
|
1585 |
+
// The bottom warps update their values.
|
1586 |
+
if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
|
1587 |
+
out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
|
1588 |
+
}
|
1589 |
+
__syncthreads();
|
1590 |
+
}
|
1591 |
+
}
|
1592 |
+
|
1593 |
+
// Output the final values.
|
1594 |
+
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
|
1595 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
|
1596 |
+
if (params.int8_mode == 2) {
|
1597 |
+
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
|
1598 |
+
out = mul<V_vec_acum, float>(*params.attention_out_scale, out);
|
1599 |
+
*reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
|
1600 |
+
cast_to_int8(out);
|
1601 |
+
}
|
1602 |
+
else {
|
1603 |
+
convert_from_float(*reinterpret_cast<V_vec*>(¶ms.out[bhi * Dh + vi]), out);
|
1604 |
+
}
|
1605 |
+
#else
|
1606 |
+
// TODO: support int8_mode?
|
1607 |
+
*reinterpret_cast<V_vec*>(¶ms.out[bhi * Dh + vi]) = out;
|
1608 |
+
#endif
|
1609 |
+
}
|
1610 |
+
}
|
1611 |
+
|
1612 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1613 |
+
|
1614 |
+
} // namespace mmha
|
1615 |
+
|
1616 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1617 |
+
|
1618 |
+
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
|
1619 |
+
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
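// Illustrative only (not part of the original source): callers are expected to instantiate
// mmha_launch_kernel once per supported head size and dispatch at runtime, roughly:
//
//   template<typename T, typename PARAMS>
//   void example_dispatch(const PARAMS& params, cudaStream_t stream)
//   {
//       switch (params.hidden_size_per_head) {
//           case 64:  mmha_launch_kernel<T, 64, 64, PARAMS>(params, stream); break;
//           case 128: mmha_launch_kernel<T, 128, 128, PARAMS>(params, stream); break;
//           default:  assert(false && "unsupported head size");
//       }
//   }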
|
decoder_masked_multihead_attention_utils.h
ADDED
@@ -0,0 +1,2017 @@
1 |
+
// Downloaded from FasterTransformer v5.2.1
|
2 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
3 |
+
/*
|
4 |
+
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
|
19 |
+
#pragma once
|
20 |
+
|
21 |
+
#include "cuda_bf16_wrapper.h"
|
22 |
+
#include "cuda_bf16_fallbacks.cuh"
|
23 |
+
#include <stdint.h>
|
24 |
+
|
25 |
+
using namespace fastertransformer;
|
26 |
+
|
27 |
+
namespace mmha {
|
28 |
+
|
29 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
30 |
+
|
31 |
+
struct Float8_ {
|
32 |
+
float2 x;
|
33 |
+
float2 y;
|
34 |
+
float2 z;
|
35 |
+
float2 w;
|
36 |
+
};
|
37 |
+
|
38 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
39 |
+
|
40 |
+
struct Float4_ {
|
41 |
+
float2 x;
|
42 |
+
float2 y;
|
43 |
+
};
|
44 |
+
|
45 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
46 |
+
|
47 |
+
#ifdef ENABLE_BF16
|
48 |
+
struct bf16_4_t {
|
49 |
+
__nv_bfloat162 x;
|
50 |
+
__nv_bfloat162 y;
|
51 |
+
};
|
52 |
+
|
53 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
54 |
+
|
55 |
+
struct bf16_8_t {
|
56 |
+
__nv_bfloat162 x;
|
57 |
+
__nv_bfloat162 y;
|
58 |
+
__nv_bfloat162 z;
|
59 |
+
__nv_bfloat162 w;
|
60 |
+
};
|
61 |
+
#endif
|
62 |
+
|
63 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
64 |
+
|
65 |
+
template<typename T>
|
66 |
+
struct num_elems;
|
67 |
+
template<>
|
68 |
+
struct num_elems<float> {
|
69 |
+
static constexpr int value = 1;
|
70 |
+
};
|
71 |
+
template<>
|
72 |
+
struct num_elems<float2> {
|
73 |
+
static constexpr int value = 2;
|
74 |
+
};
|
75 |
+
template<>
|
76 |
+
struct num_elems<float4> {
|
77 |
+
static constexpr int value = 4;
|
78 |
+
};
|
79 |
+
template<>
|
80 |
+
struct num_elems<Float4_> {
|
81 |
+
static constexpr int value = 4;
|
82 |
+
};
|
83 |
+
template<>
|
84 |
+
struct num_elems<Float8_> {
|
85 |
+
static constexpr int value = 8;
|
86 |
+
};
|
87 |
+
|
88 |
+
template<>
|
89 |
+
struct num_elems<uint32_t> {
|
90 |
+
static constexpr int value = 2;
|
91 |
+
};
|
92 |
+
template<>
|
93 |
+
struct num_elems<uint2> {
|
94 |
+
static constexpr int value = 4;
|
95 |
+
};
|
96 |
+
template<>
|
97 |
+
struct num_elems<uint4> {
|
98 |
+
static constexpr int value = 8;
|
99 |
+
};
|
100 |
+
|
101 |
+
#ifdef ENABLE_BF16
|
102 |
+
template<>
|
103 |
+
struct num_elems<__nv_bfloat162> {
|
104 |
+
static constexpr int value = 2;
|
105 |
+
};
|
106 |
+
template<>
|
107 |
+
struct num_elems<bf16_4_t> {
|
108 |
+
static constexpr int value = 4;
|
109 |
+
};
|
110 |
+
template<>
|
111 |
+
struct num_elems<bf16_8_t> {
|
112 |
+
static constexpr int value = 8;
|
113 |
+
};
|
114 |
+
#endif
|
115 |
+
|
116 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
117 |
+
|
118 |
+
template<typename T, int N>
|
119 |
+
struct packed_type;
|
120 |
+
template<typename T>
|
121 |
+
struct packed_type<T, 1> {
|
122 |
+
using type = T;
|
123 |
+
};
|
124 |
+
template<>
|
125 |
+
struct packed_type<int8_t, 2> {
|
126 |
+
using type = int16_t;
|
127 |
+
};
|
128 |
+
template<>
|
129 |
+
struct packed_type<int8_t, 4> {
|
130 |
+
using type = int32_t;
|
131 |
+
};
|
132 |
+
template<>
|
133 |
+
struct packed_type<int8_t, 8> {
|
134 |
+
using type = int64_t;
|
135 |
+
};
|
136 |
+
|
137 |
+
template<>
|
138 |
+
struct packed_type<float, 2> {
|
139 |
+
using type = float2;
|
140 |
+
};
|
141 |
+
template<>
|
142 |
+
struct packed_type<float, 4> {
|
143 |
+
using type = float4;
|
144 |
+
};
|
145 |
+
template<>
|
146 |
+
struct packed_type<float, 8> {
|
147 |
+
using type = Float8_;
|
148 |
+
};
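// Illustrative only (not part of the original source): num_elems<> and packed_type<> are
// combined in the kernel to derive matching int8 / float vector types for a given Q/K/V
// vector type. Example for uint4, which carries 8 packed fp16 elements:
using example_packed_int8_t  = typename packed_type<int8_t, num_elems<uint4>::value>::type;  // == int64_t
using example_packed_float_t = typename packed_type<float, num_elems<uint4>::value>::type;   // == Float8_
static_assert(num_elems<uint4>::value == 8, "a uint4 register holds 8 packed fp16 values");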
|
149 |
+
|
150 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
151 |
+
|
152 |
+
inline __device__ float add(float a, float b)
{
    return a + b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 add(float2 a, float2 b)
{
    float2 c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 add(float4 a, float4 b)
{
    float4 c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    c.z = add(a.z, b.z);
    c.w = add(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
    return a + b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hadd2(a, b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b)
{
    bf16_4_t c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b)
{
    bf16_8_t c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    c.z = add(a.z, b.z);
    c.w = add(a.w, b.w);
    return c;
}
#endif // ENABLE_BF16

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint16_t add(uint16_t a, uint16_t b)
{
    uint16_t c;
    asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t add(uint32_t a, uint32_t b)
{
    uint32_t c;
    asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint2 add(uint2 a, uint2 b)
{
    uint2 c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint4 add(uint4 a, uint4 b)
{
    uint4 c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    c.z = add(a.z, b.z);
    c.w = add(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint16_t float_to_half(float f)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp;
#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better?
    float zero = 0.f;
    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
#else
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#endif
    return tmp.u16[0];
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t float2_to_half2(float2 f)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
    return tmp.u32;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float half_to_float(uint16_t h)
{
    float f;
    asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
    return f;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 half2_to_float2(uint32_t v)
{
    uint16_t lo, hi;
    asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
    return make_float2(half_to_float(lo), half_to_float(hi));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float add(float a, uint16_t b)
{
    return a + half_to_float(b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
inline __device__ float add(float a, __nv_bfloat16 b)
{
    return a + __bfloat162float(b);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 add(uint32_t a, float2 fb)
{
    float2 fa = half2_to_float2(a);
    return add(fa, fb);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ add(uint2 a, Float4_ fb)
{
    Float4_ fc;
    fc.x = add(a.x, fb.x);
    fc.y = add(a.y, fb.y);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ add(uint4 a, Float8_ fb)
{
    Float8_ fc;
    fc.x = add(a.x, fb.x);
    fc.y = add(a.y, fb.y);
    fc.z = add(a.z, fb.z);
    fc.w = add(a.w, fb.w);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t h0_h0(uint16_t a)
{
    uint32_t b;
    asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
    return b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float fma(float a, float b, float c)
{
    return a * b + c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 fma(float2 a, float2 b, float2 c)
{
    float2 d;
    d.x = fma(a.x, b.x, c.x);
    d.y = fma(a.y, b.y, c.y);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 fma(float a, float2 b, float2 c)
{
    float2 d;
    d.x = fma(a, b.x, c.x);
    d.y = fma(a, b.y, c.y);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 fma(float4 a, float4 b, float4 c)
{
    float4 d;
    d.x = fma(a.x, b.x, c.x);
    d.y = fma(a.y, b.y, c.y);
    d.z = fma(a.z, b.z, c.z);
    d.w = fma(a.w, b.w, c.w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 fma(float a, float4 b, float4 c)
{
    float4 d;
    d.x = fma(a, b.x, c.x);
    d.y = fma(a, b.y, c.y);
    d.z = fma(a, b.z, c.z);
    d.w = fma(a, b.w, c.w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c)
{
    Float4_ d;
    d.x = fma(a, b.x, c.x);
    d.y = fma(a, b.y, c.y);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c)
{
    Float8_ d;
    d.x = fma(a, b.x, c.x);
    d.y = fma(a, b.y, c.y);
    d.z = fma(a, b.z, c.z);
    d.w = fma(a, b.w, c.w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
inline __device__ float2 add(__nv_bfloat162 a, float2 fb)
{
    float2 fa = bf1622float2(a);
    return add(fa, fb);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ add(bf16_4_t a, Float4_ fb)
{
    Float4_ fc;
    fc.x = add(a.x, fb.x);
    fc.y = add(a.y, fb.y);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ add(bf16_8_t a, Float8_ fb)
{
    Float8_ fc;
    fc.x = add(a.x, fb.x);
    fc.y = add(a.y, fb.y);
    fc.z = add(a.z, fb.z);
    fc.w = add(a.w, fb.w);
    return fc;
}
#endif // ENABLE_BF16

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c)
{
    uint32_t d;
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c)
{
    return fma(h0_h0(a), b, c);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c)
{
    uint2 d;
    d.x = fma(a.x, b.x, c.x);
    d.y = fma(a.y, b.y, c.y);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c)
{
    uint32_t s = h0_h0(a);
    uint2 d;
    d.x = fma(s, b.x, c.x);
    d.y = fma(s, b.y, c.y);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c)
{
    uint4 d;
    d.x = fma(a.x, b.x, c.x);
    d.y = fma(a.y, b.y, c.y);
    d.z = fma(a.z, b.z, c.z);
    d.w = fma(a.w, b.w, c.w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c)
{
    uint32_t s = h0_h0(a);
    uint4 d;
    d.x = fma(s, b.x, c.x);
    d.y = fma(s, b.y, c.y);
    d.z = fma(s, b.z, c.z);
    d.w = fma(s, b.w, c.w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float fma(uint16_t a, uint16_t b, float fc)
{
    float fa = half_to_float(a);
    float fb = half_to_float(b);
    return fa * fb + fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc)
{
    float2 fa = half2_to_float2(a);
    float2 fb = half2_to_float2(b);
    return fma(fa, fb, fc);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc)
{
    return fma(h0_h0(a), b, fc);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc)
{
    Float4_ fd;
    fd.x = fma(a.x, b.x, fc.x);
    fd.y = fma(a.y, b.y, fc.y);
    return fd;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc)
{
    uint32_t s = h0_h0(a);
    Float4_ fd;
    fd.x = fma(s, b.x, fc.x);
    fd.y = fma(s, b.y, fc.y);
    return fd;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc)
{
    Float8_ fd;
    fd.x = fma(a.x, b.x, fc.x);
    fd.y = fma(a.y, b.y, fc.y);
    fd.z = fma(a.z, b.z, fc.z);
    fd.w = fma(a.w, b.w, fc.w);
    return fd;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc)
{
    uint32_t s = h0_h0(a);
    Float8_ fd;
    fd.x = fma(s, b.x, fc.x);
    fd.y = fma(s, b.y, fc.y);
    fd.z = fma(s, b.z, fc.z);
    fd.w = fma(s, b.w, fc.w);
    return fd;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hfma2(a, b, c);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
    return bf16hfma2(bf162bf162(a), b, c);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c)
{
    bf16_4_t d;
    d.x = fma(a.x, b.x, c.x);
    d.y = fma(a.y, b.y, c.y);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c)
{
    __nv_bfloat162 s = bf162bf162(a);
    bf16_4_t d;
    d.x = fma(s, b.x, c.x);
    d.y = fma(s, b.y, c.y);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c)
{
    bf16_8_t d;
    d.x = fma(a.x, b.x, c.x);
    d.y = fma(a.y, b.y, c.y);
    d.z = fma(a.z, b.z, c.z);
    d.w = fma(a.w, b.w, c.w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c)
{
    __nv_bfloat162 s = bf162bf162(a);
    bf16_8_t d;
    d.x = fma(s, b.x, c.x);
    d.y = fma(s, b.y, c.y);
    d.z = fma(s, b.z, c.z);
    d.w = fma(s, b.w, c.w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc)
{
    return __bfloat162float(a) * __bfloat162float(b) + fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc)
{
    float2 fa = bf1622float2(a);
    float2 fb = bf1622float2(b);
    return fma(fa, fb, fc);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc)
{
    return fma(bf162bf162(a), b, fc);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc)
{
    Float4_ fd;
    fd.x = fma(a.x, b.x, fc.x);
    fd.y = fma(a.y, b.y, fc.y);
    return fd;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc)
{
    __nv_bfloat162 s = bf162bf162(a);
    Float4_ fd;
    fd.x = fma(s, b.x, fc.x);
    fd.y = fma(s, b.y, fc.y);
    return fd;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc)
{
    Float8_ fd;
    fd.x = fma(a.x, b.x, fc.x);
    fd.y = fma(a.y, b.y, fc.y);
    fd.z = fma(a.z, b.z, fc.z);
    fd.w = fma(a.w, b.w, fc.w);
    return fd;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc)
{
    __nv_bfloat162 s = bf162bf162(a);
    Float8_ fd;
    fd.x = fma(s, b.x, fc.x);
    fd.y = fma(s, b.y, fc.y);
    fd.z = fma(s, b.z, fc.z);
    fd.w = fma(s, b.w, fc.w);
    return fd;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b)
{
    return a * b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float mul<float, float>(float a, float b)
{
    return a * b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float2 mul(float2 a, float2 b)
{
    float2 c;
    c.x = a.x * b.x;
    c.y = a.y * b.y;
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float2 mul(float a, float2 b)
{
    float2 c;
    c.x = a * b.x;
    c.y = a * b.y;
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float4 mul(float4 a, float4 b)
{
    float4 c;
    c.x = a.x * b.x;
    c.y = a.y * b.y;
    c.z = a.z * b.z;
    c.w = a.w * b.w;
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float4 mul(float a, float4 b)
{
    float4 c;
    c.x = a * b.x;
    c.y = a * b.y;
    c.z = a * b.z;
    c.w = a * b.w;
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float8_ mul(float a, Float8_ b)
{
    Float8_ c;
    c.x = make_float2(a * b.x.x, a * b.x.y);
    c.y = make_float2(a * b.y.x, a * b.y.y);
    c.z = make_float2(a * b.z.x, a * b.z.y);
    c.w = make_float2(a * b.w.x, a * b.w.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b)
{
    uint16_t c;
    asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b)
{
    uint32_t c;
    asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b)
{
    return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ uint2 mul(uint2 a, uint2 b)
{
    uint2 c;
    c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
    c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ uint2 mul(uint16_t a, uint2 b)
{
    uint32_t s = h0_h0(a);
    uint2 c;
    c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
    c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ uint4 mul(uint4 a, uint4 b)
{
    uint4 c;
    c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
    c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
    c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
    c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ uint4 mul(uint16_t a, uint4 b)
{
    uint32_t s = h0_h0(a);
    uint4 c;
    c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
    c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
    c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
    c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float mul(uint16_t a, uint16_t b)
{
    float fa = half_to_float(a);
    float fb = half_to_float(b);
    return fa * fb;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float mul(uint16_t a, float b)
{
    return half_to_float(a) * b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float2 mul(uint32_t a, uint32_t b)
{
    float2 fa = half2_to_float2(a);
    float2 fb = half2_to_float2(b);
    return mul<float2, float2, float2>(fa, fb);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float2 mul(uint16_t a, uint32_t b)
{
    return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float4_ mul(uint2 a, uint2 b)
{
    Float4_ fc;
    fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
    fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b)
{
    uint32_t s = h0_h0(a);
    Float4_ fc;
    fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
    fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float8_ mul(uint4 a, uint4 b)
{
    Float8_ fc;
    fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
    fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
    fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
    fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b)
{
    uint32_t s = h0_h0(a);
    Float8_ fc;
    fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
    fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
    fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
    fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __hmul(a, b);
#else
    return bf16hmul(a, b);
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
    return bf16hmul2(a, b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
    return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b)
{
    bf16_4_t c;
    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
{
    __nv_bfloat162 s = bf162bf162(a);
    bf16_4_t c;
    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
{
    bf16_8_t c;
    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
    c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
    c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b)
{
    __nv_bfloat162 s = bf162bf162(a);
    bf16_8_t c;
    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
    c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
    float fa = (float)a;
    float fb = (float)b;
    return fa * fb;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float mul(__nv_bfloat16 a, float b)
{
    return __bfloat162float(a) * b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
    float2 fa = bf1622float2(a);
    float2 fb = bf1622float2(b);
    return mul<float2, float2, float2>(fa, fb);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
    return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b)
{
    Float4_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b)
{
    __nv_bfloat162 s = bf162bf162(a);
    Float4_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b)
{
    Float8_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
    fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
    fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b)
{
    __nv_bfloat162 s = bf162bf162(a);
    Float8_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
    fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
    return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float v)
{
    return v;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(float2 v)
{
    return v.x + v.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(float4 v)
{
    return v.x + v.y + v.z + v.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
inline __device__ float sum(__nv_bfloat162 v)
{
    float2 vf = bf1622float2(v);
    return vf.x + vf.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(bf16_4_t v)
{
    return sum(v.x) + sum(v.y);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(bf16_8_t v)
{
    return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(uint16_t v)
{
    return half_to_float(v);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(uint32_t v)
{
    float2 tmp = half2_to_float2(v);
    return tmp.x + tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(uint2 v)
{
    uint32_t c = add(v.x, v.y);
    return sum(c);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(uint4 v)
{
#if 1
    uint32_t c = add(v.x, v.y);
    c = add(c, v.z);
    c = add(c, v.w);
#else
    uint32_t c = add(v.x, v.y);
    uint32_t d = add(v.z, v.w);
    c = add(c, d);
#endif
    return sum(c);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(Float4_ v)
{
    return v.x.x + v.x.y + v.y.x + v.y.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(Float8_ v)
{
    return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ float dot(T a, T b)
{
    return sum(mul<T, T, T>(a, b));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename A, typename T>
inline __device__ float dot(T a, T b)
{
    return sum(mul<A, T, T>(a, b));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void zero(uint16_t& dst)
{
    dst = uint16_t(0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ void zero(T& dst)
{
    constexpr int WORDS = sizeof(T) / 4;
    union {
        T raw;
        uint32_t words[WORDS];
    } tmp;
#pragma unroll
    for (int ii = 0; ii < WORDS; ++ii) {
        tmp.words[ii] = 0u;
    }
    dst = tmp.raw;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
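// Illustrative sketch (the function name below is hypothetical, not part of the header): dot()
// simply composes the element-wise mul and the horizontal sum defined above, so for two float2
// operands it reduces to an ordinary two-element dot product accumulated in fp32.
inline __device__ float dot2_example(float2 a, float2 b)
{
    // Same result as a.x * b.x + a.y * b.y.
    return dot(a, b);
}
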
inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const int t_step, const float base)
{
    const float pos_idx_inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
    return {cos(pos_idx_inv_freq), sin(pos_idx_inv_freq)};
}

inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
{
    float2 rot_v;
    rot_v.x = coef.x * v.x - coef.y * v.y;
    rot_v.y = coef.x * v.y + coef.y * v.x;
    return rot_v;
}

inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
{
    float2 fv = half2_to_float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return float2_to_half2(rot_fv);
}

#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
{
    float2 fv = bf1622float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
#endif

inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    return;
}

inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    return;
}

inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }

    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q_.x = rotary_embedding_transform(q_.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q_.y = rotary_embedding_transform(q_.y, coef1);
}

inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }

    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
    Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q_.x = rotary_embedding_transform(q_.x, coef0);
    k_.x = rotary_embedding_transform(k_.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q_.y = rotary_embedding_transform(q_.y, coef1);
    k_.y = rotary_embedding_transform(k_.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    k.z = rotary_embedding_transform(k.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
    k.w = rotary_embedding_transform(k.w, coef3);
}

#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    k.z = rotary_embedding_transform(k.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
    k.w = rotary_embedding_transform(k.w, coef3);
}
#endif // ENABLE_BF16
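// For reference, a compact restatement of the computation implemented by
// rotary_embedding_coefficient / rotary_embedding_transform above (illustration only): for the
// (even, odd) pair of the head dimension at index zid and time step t,
//     theta = t / base^(zid / rot_embed_dim),   coef = (cos(theta), sin(theta)),
// and the transform applies the 2-D rotation
//     x' = cos(theta) * x - sin(theta) * y,   y' = cos(theta) * y + sin(theta) * x,
// i.e. the standard rotary position embedding with frequency base `base`.
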
template <typename T>
|
1520 |
+
inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const T* rotary_cos, const T* rotary_sin)
|
1521 |
+
{
|
1522 |
+
// zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
|
1523 |
+
// rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
|
1524 |
+
return {float(rotary_cos[zid / 2]), float(rotary_sin[zid / 2])};
|
1525 |
+
}
|
1526 |
+
|
1527 |
+
// fp16 is special because we use uint16_t for reading the data, for backward compatibility.
|
1528 |
+
template <>
|
1529 |
+
inline __device__ float2 rotary_embedding_coefficient<uint16_t>(const int zid, const int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
|
1530 |
+
{
|
1531 |
+
// zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
|
1532 |
+
// rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
|
1533 |
+
return {float(reinterpret_cast<const __half*>(rotary_cos)[zid / 2]),
|
1534 |
+
float(reinterpret_cast<const __half*>(rotary_sin)[zid / 2])};
|
1535 |
+
}
|
1536 |
+
|
1537 |
+
inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
|
1538 |
+
{
|
1539 |
+
return;
|
1540 |
+
}
|
1541 |
+
|
1542 |
+
inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
|
1543 |
+
{
|
1544 |
+
return;
|
1545 |
+
}
|
1546 |
+
|
1547 |
+
inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
|
1548 |
+
{
|
1549 |
+
if (2 * tid >= rot_embed_dim) {
|
1550 |
+
return;
|
1551 |
+
}
|
1552 |
+
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
|
1553 |
+
q = rotary_embedding_transform(q, coef);
|
1554 |
+
}
|
1555 |
+
|
1556 |
+
inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
|
1557 |
+
{
|
1558 |
+
if (2 * tid >= rot_embed_dim) {
|
1559 |
+
return;
|
1560 |
+
}
|
1561 |
+
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
|
1562 |
+
q = rotary_embedding_transform(q, coef);
|
1563 |
+
k = rotary_embedding_transform(k, coef);
|
1564 |
+
}
|
1565 |
+
|
1566 |
+
inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
|
1567 |
+
{
|
1568 |
+
if (4 * tid >= rot_embed_dim) {
|
1569 |
+
return;
|
1570 |
+
}
|
1571 |
+
|
1572 |
+
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
|
1573 |
+
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
|
1574 |
+
q_.x = rotary_embedding_transform(q_.x, coef0);
|
1575 |
+
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
|
1576 |
+
q_.y = rotary_embedding_transform(q_.y, coef1);
|
1577 |
+
}
|
1578 |
+
|
1579 |
+
inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
|
1580 |
+
{
|
1581 |
+
if (4 * tid >= rot_embed_dim) {
|
1582 |
+
return;
|
1583 |
+
}
|
1584 |
+
|
1585 |
+
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
|
1586 |
+
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
|
1587 |
+
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
|
1588 |
+
q_.x = rotary_embedding_transform(q_.x, coef0);
|
1589 |
+
k_.x = rotary_embedding_transform(k_.x, coef0);
|
1590 |
+
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
|
1591 |
+
q_.y = rotary_embedding_transform(q_.y, coef1);
|
1592 |
+
k_.y = rotary_embedding_transform(k_.y, coef1);
|
1593 |
+
}
|
1594 |
+
|
1595 |
+
inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
    q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
    q.y = rotary_embedding_transform(q.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
    q.z = rotary_embedding_transform(q.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
    q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
    q.z = rotary_embedding_transform(q.z, coef2);
    k.z = rotary_embedding_transform(k.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
    q.w = rotary_embedding_transform(q.w, coef3);
    k.w = rotary_embedding_transform(k.w, coef3);
}
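// bf16 counterparts of the rotary helpers above. They are compiled only when
// ENABLE_BF16 is defined, with __nv_bfloat162 / bf16_4_t / bf16_8_t playing
// the role of the packed half types.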
#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
    q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
    q.y = rotary_embedding_transform(q.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
    q.z = rotary_embedding_transform(q.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
    q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
    q.z = rotary_embedding_transform(q.z, coef2);
    k.z = rotary_embedding_transform(k.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
    q.w = rotary_embedding_transform(q.w, coef3);
    k.w = rotary_embedding_transform(k.w, coef3);
}
#endif // ENABLE_BF16
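// vec_from_smem_transpose gathers matching elements from two shared-memory
// rows separated by smem_pitch (starting at column transpose_idx) and
// interleaves them into one packed register -- a small two-row transpose on
// load. The unspecialized template is only declared; each supported
// (vector type, scalar type) pair gets an explicit specialization below.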
template<typename Vec_T, typename T>
__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);

template<>
__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch)
{
    return;
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp;
    tmp.u16[0] = smem[transpose_idx];
    tmp.u16[1] = smem[smem_pitch + transpose_idx];

    vec = tmp.u32;
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp_1, tmp_2;
    tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
    tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);

    union {
        uint2 u32x2;
        uint16_t u16[4];
    } tmp_3;
    tmp_3.u16[0] = tmp_1.u16[0];
    tmp_3.u16[1] = tmp_2.u16[0];
    tmp_3.u16[2] = tmp_1.u16[1];
    tmp_3.u16[3] = tmp_2.u16[1];

    vec = tmp_3.u32x2;
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint64_t u64;
        uint16_t u16[4];
    } tmp_1, tmp_2;
    tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
    tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);

    union {
        uint4 u32x4;
        uint16_t u16[8];
    } tmp_3;
    tmp_3.u16[0] = tmp_1.u16[0];
    tmp_3.u16[1] = tmp_2.u16[0];
    tmp_3.u16[2] = tmp_1.u16[1];
    tmp_3.u16[3] = tmp_2.u16[1];
    tmp_3.u16[4] = tmp_1.u16[2];
    tmp_3.u16[5] = tmp_2.u16[2];
    tmp_3.u16[6] = tmp_1.u16[3];
    tmp_3.u16[7] = tmp_2.u16[3];

    vec = tmp_3.u32x4;
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        __nv_bfloat16 bf16[2];
    } tmp_1, tmp_2;
    tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
    tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);

    vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
    vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
}

template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint64_t u64;
        __nv_bfloat16 bf16[4];
    } tmp_1, tmp_2;
    tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
    tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);

    vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
    vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
    vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
    vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
}
#endif // ENABLE_BF16

template<>
__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
    vec.x = smem[transpose_idx];
    vec.z = smem[transpose_idx + 1];
    vec.y = smem[smem_pitch + transpose_idx];
    vec.w = smem[smem_pitch + transpose_idx + 1];
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        half u16[2];
    } tmp;
    tmp.u16[0] = smem[transpose_idx];
    tmp.u16[1] = smem[smem_pitch + transpose_idx];

    vec = tmp.u32;
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    vec.x = smem[transpose_idx];
    vec.y = smem[smem_pitch + transpose_idx];
}
#endif

template<>
__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
    vec.x = smem[transpose_idx];
    vec.y = smem[smem_pitch + transpose_idx];
}
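// write_smem_transpose is the inverse of vec_from_smem_transpose: it unpacks a
// register and scatters the even lanes to the row at transpose_idx and the odd
// lanes to the row at smem_pitch + transpose_idx.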
template<typename Vec_T, typename T>
__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);

template<>
__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
{
    return;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint64_t u64;
        uint16_t u16[4];
    } tmp_1, tmp_2;

    union {
        uint4 u32x4;
        uint16_t u16[8];
    } tmp_3;
    tmp_3.u32x4 = vec;
    tmp_1.u16[0] = tmp_3.u16[0];
    tmp_2.u16[0] = tmp_3.u16[1];
    tmp_1.u16[1] = tmp_3.u16[2];
    tmp_2.u16[1] = tmp_3.u16[3];
    tmp_1.u16[2] = tmp_3.u16[4];
    tmp_2.u16[2] = tmp_3.u16[5];
    tmp_1.u16[3] = tmp_3.u16[6];
    tmp_2.u16[3] = tmp_3.u16[7];

    *reinterpret_cast<uint64_t*>(&smem[transpose_idx]) = tmp_1.u64;
    *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp_1, tmp_2;

    union {
        uint2 u32x2;
        uint16_t u16[4];
    } tmp_3;
    tmp_3.u32x2 = vec;
    tmp_1.u16[0] = tmp_3.u16[0];
    tmp_2.u16[0] = tmp_3.u16[1];
    tmp_1.u16[1] = tmp_3.u16[2];
    tmp_2.u16[1] = tmp_3.u16[3];

    *reinterpret_cast<uint32_t*>(&smem[transpose_idx]) = tmp_1.u32;
    *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp;
    tmp.u32 = vec;

    smem[transpose_idx] = tmp.u16[0];
    smem[smem_pitch + transpose_idx] = tmp.u16[1];
}

template<>
__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
    smem[transpose_idx] = vec.x;
    smem[transpose_idx + 1] = vec.z;
    smem[smem_pitch + transpose_idx] = vec.y;
    smem[smem_pitch + transpose_idx + 1] = vec.w;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        half u16[2];
    } tmp;

    tmp.u32 = vec;
    smem[transpose_idx] = tmp.u16[0];
    smem[smem_pitch + transpose_idx] = tmp.u16[1];
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    smem[transpose_idx] = vec.x;
    smem[smem_pitch + transpose_idx] = vec.y;
}

template<>
__device__ __inline__ void
write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}

template<>
__device__ __inline__ void
write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}
#endif

template<>
__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
    smem[transpose_idx] = vec.x;
    smem[smem_pitch + transpose_idx] = vec.y;
}

} // namespace mmha
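// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of the original header): one way a
// decoding kernel might call the fp16 rotary helpers above. The kernel name,
// the assumption that each thread owns one uint2 (four half elements) of q
// and k, and the flat q/k layout are hypothetical; only apply_rotary_embedding
// and its argument order come from the header itself.
// ---------------------------------------------------------------------------
#include "decoder_masked_multihead_attention_utils.h"

__global__ void rope_sketch_kernel(uint2* q, uint2* k, int rot_embed_dim,
                                   int t_step,
                                   const uint16_t* rotary_cos,
                                   const uint16_t* rotary_sin)
{
    // Thread tid covers head elements [4 * tid, 4 * tid + 3].
    const int tid = threadIdx.x;
    uint2 q_vec = q[tid];
    uint2 k_vec = k[tid];
    // Rotates only the first rot_embed_dim elements; threads past that range
    // return immediately inside the helper.
    mmha::apply_rotary_embedding(q_vec, k_vec, tid, rot_embed_dim, t_step,
                                 rotary_cos, rotary_sin);
    q[tid] = q_vec;
    k[tid] = k_vec;
}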