Upload 62 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete upload.
- .gitattributes +5 -0
- rscd/models/backbones/lib_mamba/__init__.py +58 -0
- rscd/models/backbones/lib_mamba/__pycache__/__init__.cpython-38.pyc +0 -0
- rscd/models/backbones/lib_mamba/__pycache__/csm_triton.cpython-38.pyc +0 -0
- rscd/models/backbones/lib_mamba/__pycache__/csm_tritonk2.cpython-38.pyc +0 -0
- rscd/models/backbones/lib_mamba/__pycache__/csms6s.cpython-38.pyc +0 -0
- rscd/models/backbones/lib_mamba/__pycache__/vmamba.cpython-38.pyc +0 -0
- rscd/models/backbones/lib_mamba/__pycache__/vmambanew.cpython-38.pyc +0 -0
- rscd/models/backbones/lib_mamba/csm_triton.py +644 -0
- rscd/models/backbones/lib_mamba/csm_tritonk2.py +899 -0
- rscd/models/backbones/lib_mamba/csms6s.py +266 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/README.md +97 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so +3 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps +3 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_log +4 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/build.ninja +35 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o +3 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o +3 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o +3 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cub_extra.cuh +50 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan.cpp +354 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_bwd_kernel.cuh +306 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_bwd.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_fwd.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_fwd_kernel.cuh +203 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_bwd_kernel_ndstate.cuh +302 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_fwd_kernel_ndstate.cuh +200 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp +341 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.h +84 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_bwd_kernel_nrow.cuh +344 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu +8 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu +8 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu +9 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_fwd_kernel_nrow.cuh +238 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_nrow.cpp +367 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_bwd_kernel_oflex.cuh +323 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu +11 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu +11 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_fwd_kernel_oflex.cuh +211 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp +363 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/reverse_scan.cuh +403 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan.h +90 -0
- rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan_common.h +210 -0
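Several of the uploaded files are prebuilt artifacts of the selective-scan CUDA extension (the build/ directory, the ninja files, and the compiled .so). As an illustration only, not part of the diff, a quick check that the prebuilt module named after the uploaded .so is importable in the current environment might look like this:

# Sketch: verify the prebuilt selective_scan_cuda_oflex extension can be found.
# The module name comes from the uploaded shared object
# (selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so); it only loads on a
# matching Python 3.8 / x86-64 Linux / compatible PyTorch-CUDA setup.
import importlib.util

spec = importlib.util.find_spec("selective_scan_cuda_oflex")
if spec is None:
    print("prebuilt extension not found; rebuild it under "
          "rscd/models/backbones/lib_mamba/kernels/selective_scan")
else:
    import selective_scan_cuda_oflex
    print("loaded from", selective_scan_cuda_oflex.__file__)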
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps filter=lfs diff=lfs merge=lfs -text
+rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o filter=lfs diff=lfs merge=lfs -text
+rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o filter=lfs diff=lfs merge=lfs -text
+rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o filter=lfs diff=lfs merge=lfs -text
rscd/models/backbones/lib_mamba/__init__.py
ADDED
@@ -0,0 +1,58 @@
import os
from functools import partial
import torch

from .vmamba import VSSM
from .csms6s import flops_selective_scan_fn, flops_selective_scan_ref


def build_vssm_model(config, **kwargs):
    model_type = config.MODEL.TYPE
    if model_type in ["vssm"]:
        model = VSSM(
            patch_size=config.MODEL.VSSM.PATCH_SIZE,
            in_chans=config.MODEL.VSSM.IN_CHANS,
            num_classes=config.MODEL.NUM_CLASSES,
            depths=config.MODEL.VSSM.DEPTHS,
            dims=config.MODEL.VSSM.EMBED_DIM,
            # ===================
            ssm_d_state=config.MODEL.VSSM.SSM_D_STATE,
            ssm_ratio=config.MODEL.VSSM.SSM_RATIO,
            ssm_rank_ratio=config.MODEL.VSSM.SSM_RANK_RATIO,
            ssm_dt_rank=("auto" if config.MODEL.VSSM.SSM_DT_RANK == "auto" else int(config.MODEL.VSSM.SSM_DT_RANK)),
            ssm_act_layer=config.MODEL.VSSM.SSM_ACT_LAYER,
            ssm_conv=config.MODEL.VSSM.SSM_CONV,
            ssm_conv_bias=config.MODEL.VSSM.SSM_CONV_BIAS,
            ssm_drop_rate=config.MODEL.VSSM.SSM_DROP_RATE,
            ssm_init=config.MODEL.VSSM.SSM_INIT,
            forward_type=config.MODEL.VSSM.SSM_FORWARDTYPE,
            # ===================
            mlp_ratio=config.MODEL.VSSM.MLP_RATIO,
            mlp_act_layer=config.MODEL.VSSM.MLP_ACT_LAYER,
            mlp_drop_rate=config.MODEL.VSSM.MLP_DROP_RATE,
            # ===================
            drop_path_rate=config.MODEL.DROP_PATH_RATE,
            patch_norm=config.MODEL.VSSM.PATCH_NORM,
            norm_layer=config.MODEL.VSSM.NORM_LAYER,
            downsample_version=config.MODEL.VSSM.DOWNSAMPLE,
            patchembed_version=config.MODEL.VSSM.PATCHEMBED,
            gmlp=config.MODEL.VSSM.GMLP,
            use_checkpoint=config.TRAIN.USE_CHECKPOINT,
            # ===================
            posembed=config.MODEL.VSSM.POSEMBED,
            imgsize=config.DATA.IMG_SIZE,
        )
        return model

    return None


def build_model(config, is_pretrain=False):
    model = None
    if model is None:
        model = build_vssm_model(config, is_pretrain)
    return model
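build_vssm_model reads everything from nested config attributes, so any object exposing those fields works. A minimal sketch of calling it (the attribute values below are illustrative assumptions, not values taken from this repo; a real setup would typically use a yacs-style CfgNode):

# Sketch only: a SimpleNamespace stand-in for the expected config object.
# All numeric choices here are placeholders, not the repo's defaults.
from types import SimpleNamespace as NS

config = NS(
    MODEL=NS(
        TYPE="vssm",
        NUM_CLASSES=2,
        DROP_PATH_RATE=0.1,
        VSSM=NS(
            PATCH_SIZE=4, IN_CHANS=3, DEPTHS=[2, 2, 9, 2], EMBED_DIM=96,
            SSM_D_STATE=16, SSM_RATIO=2.0, SSM_RANK_RATIO=2.0, SSM_DT_RANK="auto",
            SSM_ACT_LAYER="silu", SSM_CONV=3, SSM_CONV_BIAS=True, SSM_DROP_RATE=0.0,
            SSM_INIT="v0", SSM_FORWARDTYPE="v2",
            MLP_RATIO=4.0, MLP_ACT_LAYER="gelu", MLP_DROP_RATE=0.0,
            PATCH_NORM=True, NORM_LAYER="ln", DOWNSAMPLE="v2", PATCHEMBED="v2",
            GMLP=False, POSEMBED=False,
        ),
    ),
    TRAIN=NS(USE_CHECKPOINT=False),
    DATA=NS(IMG_SIZE=224),
)

# Requires the repo root on PYTHONPATH so the rscd package is importable.
from rscd.models.backbones.lib_mamba import build_vssm_model
model = build_vssm_model(config)  # returns a VSSM instance when MODEL.TYPE == "vssm"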
rscd/models/backbones/lib_mamba/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.71 kB).

rscd/models/backbones/lib_mamba/__pycache__/csm_triton.cpython-38.pyc
ADDED
Binary file (19 kB).

rscd/models/backbones/lib_mamba/__pycache__/csm_tritonk2.cpython-38.pyc
ADDED
Binary file (24.3 kB).

rscd/models/backbones/lib_mamba/__pycache__/csms6s.cpython-38.pyc
ADDED
Binary file (8.38 kB).

rscd/models/backbones/lib_mamba/__pycache__/vmamba.cpython-38.pyc
ADDED
Binary file (46.4 kB).

rscd/models/backbones/lib_mamba/__pycache__/vmambanew.cpython-38.pyc
ADDED
Binary file (40.6 kB).
rscd/models/backbones/lib_mamba/csm_triton.py
ADDED
@@ -0,0 +1,644 @@
import torch
import warnings

WITH_TRITON = True
# WITH_TRITON = False
try:
    import triton
    import triton.language as tl
except:
    WITH_TRITON = False
    warnings.warn("Triton not installed, fall back to pytorch implements.")

# to make sure cached_property can be loaded for triton
if WITH_TRITON:
    try:
        from functools import cached_property
    except:
        warnings.warn("if you are using py37, add this line to functools.py: "
                      "cached_property = lambda func: property(lru_cache()(func))")


# torch implementation ========================================
def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    if in_channel_first:
        B, C, H, W = x.shape
        if scans == 0:
            y = x.new_empty((B, 4, C, H * W))
            y[:, 0, :, :] = x.flatten(2, 3)
            y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
            y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
        elif scans == 1:
            y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
        elif scans == 2:
            y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
            y = torch.cat([y, y.flip(dims=[-1])], dim=1)
    else:
        B, H, W, C = x.shape
        if scans == 0:
            y = x.new_empty((B, H * W, 4, C))
            y[:, :, 0, :] = x.flatten(1, 2)
            y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
            y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
        elif scans == 1:
            y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
        elif scans == 2:
            y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
            y = torch.cat([y, y.flip(dims=[1])], dim=2)

    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y


def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        elif scans == 1:
            y = y.sum(1)
        elif scans == 2:
            y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            y = y.sum(1)
    else:
        B, H, W, K, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
        elif scans == 1:
            y = y.sum(2)
        elif scans == 2:
            y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            y = y.sum(2)

    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 2, 1).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 1).contiguous()

    return y


def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    if in_channel_first:
        B, _, C, H, W = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 1:
            y = x.flatten(2, 3)
        elif scans == 2:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
            ], dim=1)
    else:
        B, H, W, _, C = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 1:
            y = x.flatten(1, 2)
        elif scans == 2:
            y = torch.stack([
                x[:, 0].flatten(1, 2),
                x[:, 1].flatten(1, 2),
                torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
                torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
            ], dim=2)

    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y


def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            y = torch.stack([
                y[:, 0],
                y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
                torch.flip(y[:, 2], dims=[-1]),
                torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 1:
            y = y
        elif scans == 2:
            y = torch.stack([
                y[:, 0],
                y[:, 1],
                torch.flip(y[:, 2], dims=[-1]),
                torch.flip(y[:, 3], dims=[-1]),
            ], dim=1)
    else:
        B, H, W, _, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            y = torch.stack([
                y[:, :, 0],
                y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
                torch.flip(y[:, :, 2], dims=[1]),
                torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 1:
            y = y
        elif scans == 2:
            y = torch.stack([
                y[:, :, 0],
                y[:, :, 1],
                torch.flip(y[:, :, 2], dims=[1]),
                torch.flip(y[:, :, 3], dims=[1]),
            ], dim=2)

    if out_channel_first and (not in_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not out_channel_first) and in_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y


class CrossScanF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
        # y: (B, 4, C, H * W) | (B, H * W, 4, C)
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        if one_by_one:
            B, K, C, H, W = x.shape
            if not in_channel_first:
                B, H, W, K, C = x.shape
        else:
            B, C, H, W = x.shape
            if not in_channel_first:
                B, H, W, C = x.shape
        ctx.shape = (B, C, H, W)

        _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
        y = _fn(x, in_channel_first, out_channel_first, scans)

        return y

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape

        ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
        _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
        y = _fn(ys, in_channel_first, out_channel_first, scans)

        if one_by_one:
            y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1)
        else:
            y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)

        return y, None, None, None, None


class CrossMergeF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
        # y: (B, 4, C, H * W) | (B, H * W, 4, C)
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        B, K, C, H, W = ys.shape
        if not out_channel_first:
            B, H, W, K, C = ys.shape
        ctx.shape = (B, C, H, W)

        _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
        y = _fn(ys, in_channel_first, out_channel_first, scans)

        return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, h, w)
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape

        if not one_by_one:
            if in_channel_first:
                x = x.view(B, C, H, W)
            else:
                x = x.view(B, H, W, C)
        else:
            if in_channel_first:
                x = x.view(B, 4, C, H, W)
            else:
                x = x.view(B, H, W, 4, C)

        _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
        x = _fn(x, in_channel_first, out_channel_first, scans)
        x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C)

        return x, None, None, None, None


# triton implements ========================================

@triton.jit
def triton_cross_scan_flex(
    x,  # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
    y,  # (B, 4, C, H, W) | (B, H, W, 4, C)
    x_layout: tl.constexpr,
    y_layout: tl.constexpr,
    operation: tl.constexpr,
    onebyone: tl.constexpr,
    scans: tl.constexpr,
    BC: tl.constexpr,
    BH: tl.constexpr,
    BW: tl.constexpr,
    DC: tl.constexpr,
    DH: tl.constexpr,
    DW: tl.constexpr,
    NH: tl.constexpr,
    NW: tl.constexpr,
):
    # x_layout = 0
    # y_layout = 1 # 0 BCHW, 1 BHWC
    # operation = 0 # 0 scan, 1 merge
    # onebyone = 0 # 0 false, 1 true
    # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional

    i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h, i_w = (i_hw // NW), (i_hw % NW)
    _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
    _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
    _mask_hw = _mask_h[:, None] & _mask_w[None, :]
    _for_C = min(DC - i_c * BC, BC)

    HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
    HWRoute1 = i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None]  # trans
    HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW)  # flip
    HWRoute3 = (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH  # trans + flip

    if scans == 1:
        HWRoute1 = HWRoute0
        HWRoute2 = HWRoute0
        HWRoute3 = HWRoute0
    elif scans == 2:
        HWRoute1 = HWRoute0
        HWRoute3 = HWRoute2

    _tmp1 = DC * DH * DW

    y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
    if y_layout == 0:
        p_y1 = y_ptr_base + HWRoute0
        p_y2 = y_ptr_base + _tmp1 + HWRoute1
        p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
        p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
    else:
        p_y1 = y_ptr_base + HWRoute0 * 4 * DC
        p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
        p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
        p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC

    if onebyone == 0:
        x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
        if x_layout == 0:
            p_x = x_ptr_base + HWRoute0
        else:
            p_x = x_ptr_base + HWRoute0 * DC

        if operation == 0:
            for idxc in range(_for_C):
                _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                _x = tl.load(p_x + _idx_x, mask=_mask_hw)
                tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
                tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
                tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
                tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
        elif operation == 1:
            for idxc in range(_for_C):
                _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
                _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
                _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
                _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
                tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)

    else:
        x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
        if x_layout == 0:
            p_x1 = x_ptr_base + HWRoute0
            p_x2 = p_x1 + _tmp1
            p_x3 = p_x2 + _tmp1
            p_x4 = p_x3 + _tmp1
        else:
            p_x1 = x_ptr_base + HWRoute0 * 4 * DC
            p_x2 = p_x1 + DC
            p_x3 = p_x2 + DC
            p_x4 = p_x3 + DC

        if operation == 0:
            for idxc in range(_for_C):
                _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
        else:
            for idxc in range(_for_C):
                _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
                tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
                tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
                tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)


class CrossScanTritonF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        if one_by_one:
            if in_channel_first:
                B, _, C, H, W = x.shape
            else:
                B, H, W, _, C = x.shape
        else:
            if in_channel_first:
                B, C, H, W = x.shape
            else:
                B, H, W, C = x.shape
        B, C, H, W = int(B), int(C), int(H), int(W)
        BC, BH, BW = 1, 32, 32
        NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)

        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)

        y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x.contiguous(), y,
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return y

    @staticmethod
    def backward(ctx, y: torch.Tensor):
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape
        if one_by_one:
            x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
        else:
            x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))

        triton_cross_scan_flex[(NH * NW, NC, B)](
            x, y.contiguous(),
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return x, None, None, None, None


class CrossMergeTritonF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        if out_channel_first:
            B, _, C, H, W = y.shape
        else:
            B, H, W, _, C = y.shape
        B, C, H, W = int(B), int(C), int(H), int(W)
        BC, BH, BW = 1, 32, 32
        NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
        if one_by_one:
            x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
        else:
            x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x, y.contiguous(),
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return x

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape
        y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x.contiguous(), y,
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return y, None, None, None, None, None


# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
    # y: (B, 4, C, L) | (B, L, 4, C)
    # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
    CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
    return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)


# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    # y: (B, 4, C, L) | (B, L, 4, C)
    # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
    # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
    CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
    return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)


# checks =================================================================

class CHECK:
    def check_csm_triton():
        B, C, H, W = 2, 192, 56, 57
        dtype = torch.float16
        dtype = torch.float32
        x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
        y = torch.randn((B, 4, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
        x1 = x.clone().detach().requires_grad_(True)
        y1 = y.clone().detach().requires_grad_(True)

        def cross_scan(x: torch.Tensor):
            B, C, H, W = x.shape
            L = H * W
            xs = torch.stack([
                x.view(B, C, L),
                torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
                torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
                torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
            ], dim=1).view(B, 4, C, L)
            return xs

        def cross_merge(out_y: torch.Tensor):
            B, K, D, H, W = out_y.shape
            L = H * W
            out_y = out_y.view(B, K, D, L)
            inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
            wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
            invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
            y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
            return y

        def cross_scan_1b1(x: torch.Tensor):
            B, K, C, H, W = x.shape
            L = H * W
            xs = torch.stack([
                x[:, 0].view(B, C, L),
                torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
                torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
                torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
            ], dim=1).view(B, 4, C, L)
            return xs

        def unidi_scan(x):
            B, C, H, W = x.shape
            x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
            return x

        def unidi_merge(ys):
            B, K, C, H, W = ys.shape
            return ys.view(B, 4, -1, H * W).sum(1)

        def bidi_scan(x):
            B, C, H, W = x.shape
            x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
            x = torch.cat([x, x.flip(dims=[-1])], dim=1)
            return x

        def bidi_merge(ys):
            B, K, D, H, W = ys.shape
            ys = ys.view(B, K, D, -1)
            ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            return ys.contiguous().sum(1)

        if True:
            res0 = triton.testing.do_bench(lambda: cross_scan(x))
            res1 = triton.testing.do_bench(lambda: cross_scan_fn(x, True, True, False))
            # res2 = triton.testing.do_bench(lambda: CrossScanTriton.apply(x))
            res3 = triton.testing.do_bench(lambda: cross_merge(y))
            res4 = triton.testing.do_bench(lambda: cross_merge_fn(y, True, True, False))
            # res5 = triton.testing.do_bench(lambda: CrossMergeTriton.apply(y))
            # print(res0, res1, res2, res3, res4, res5)
            print(res0, res1, res3, res4)
            res0 = triton.testing.do_bench(lambda: cross_scan(x).sum().backward())
            res1 = triton.testing.do_bench(lambda: cross_scan_fn(x, True, True, False).sum().backward())
            # res2 = triton.testing.do_bench(lambda: CrossScanTriton.apply(x).sum().backward())
            res3 = triton.testing.do_bench(lambda: cross_merge(y).sum().backward())
            res4 = triton.testing.do_bench(lambda: cross_merge_fn(y, True, True, False).sum().backward())
            # res5 = triton.testing.do_bench(lambda: CrossMergeTriton.apply(y).sum().backward())
            # print(res0, res1, res2, res3, res4, res5)
            print(res0, res1, res3, res4)

        print("test cross scan")
        for (cs0, cm0, cs1, cm1) in [
            # channel_first -> channel_first
            (cross_scan, cross_merge, cross_scan_fn, cross_merge_fn),
            (unidi_scan, unidi_merge, lambda x: cross_scan_fn(x, scans=1), lambda x: cross_merge_fn(x, scans=1)),
            (bidi_scan, bidi_merge, lambda x: cross_scan_fn(x, scans=2), lambda x: cross_merge_fn(x, scans=2)),

            # flex: BLC->BCL; BCL->BLC; BLC->BLC;
            (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn(x, in_channel_first=False).permute(0, 2, 1)),
            (cross_scan, cross_merge, lambda x: cross_scan_fn(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), out_channel_first=False)),
            (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)),

            # previous
            # (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)),
            # (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)),
            # (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)),
        ]:
            x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
            o0 = cs0(x)
            o1 = cs1(x1)
            o0.backward(y.view(B, 4, C, H * W))
            o1.backward(y.view(B, 4, C, H * W))
            print((o0 - o1).abs().max())
            print((x.grad - x1.grad).abs().max())
            o0 = cm0(y)
            o1 = cm1(y1)
            o0.backward(x.view(B, C, H * W))
            o1.backward(x.view(B, C, H * W))
            print((o0 - o1).abs().max())
            print((y.grad - y1.grad).abs().max())
            x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
            print("===============", flush=True)

        print("test cross scan one by one")
        for (cs0, cs1) in [
            (cross_scan_1b1, lambda x: cross_scan_fn(x, one_by_one=True)),
            # (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)),
        ]:
            o0 = cs0(y)
            o1 = cs1(y1)
            o0.backward(y.view(B, 4, C, H * W))
            o1.backward(y.view(B, 4, C, H * W))
            print((o0 - o1).abs().max())
            print((y.grad - y1.grad).abs().max())
            x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
            print("===============", flush=True)


if __name__ == "__main__":
    CHECK.check_csm_triton()
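cross_scan_fn and cross_merge_fn above dispatch to the Triton kernel for CUDA tensors and otherwise (or with force_torch=True) to the pure-PyTorch CrossScanF/CrossMergeF path. A minimal CPU shape sketch of the round trip, with illustrative sizes and assuming the repo root is on PYTHONPATH:

# Sketch: shapes of the default four-direction cross scan and its merge,
# using the torch fallback so it runs without a GPU or Triton.
import torch
from rscd.models.backbones.lib_mamba.csm_triton import cross_scan_fn, cross_merge_fn

B, C, H, W = 2, 8, 14, 15
x = torch.randn(B, C, H, W)

# scans=0: row-major, column-major, and both reversed -> K = 4 sequences.
xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=0, force_torch=True)
print(xs.shape)  # (B, 4, C, H*W) == (2, 4, 8, 210)

# cross_merge_fn undoes the flips/transposes and sums the four sequences back
# into one (B, C, H*W) map (here equal to 4*x flattened, since all four are copies of x).
merged = cross_merge_fn(xs.view(B, 4, C, H, W), in_channel_first=True, out_channel_first=True, scans=0, force_torch=True)
print(merged.shape)  # (2, 8, 210)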
rscd/models/backbones/lib_mamba/csm_tritonk2.py
ADDED
|
@@ -0,0 +1,899 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import warnings
|
| 3 |
+
import os
|
| 4 |
+
os.environ["TRITON_INTERPRET"] = "1"
|
| 5 |
+
|
| 6 |
+
WITH_TRITON = True
|
| 7 |
+
# WITH_TRITON = False
|
| 8 |
+
try:
|
| 9 |
+
import triton
|
| 10 |
+
import triton.language as tl
|
| 11 |
+
except:
|
| 12 |
+
WITH_TRITON = False
|
| 13 |
+
warnings.warn("Triton not installed, fall back to pytorch implements.")
|
| 14 |
+
|
| 15 |
+
# to make sure cached_property can be loaded for triton
|
| 16 |
+
if WITH_TRITON:
|
| 17 |
+
try:
|
| 18 |
+
from functools import cached_property
|
| 19 |
+
except:
|
| 20 |
+
warnings.warn("if you are using py37, add this line to functools.py: "
|
| 21 |
+
"cached_property = lambda func: property(lru_cache()(func))")
|
| 22 |
+
|
| 23 |
+
# torch implementation ========================================
|
| 24 |
+
def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2):
|
| 25 |
+
if in_channel_first:
|
| 26 |
+
B, C, H, W = x.shape
|
| 27 |
+
if scans == 0:
|
| 28 |
+
y = x.new_empty((B, 4, C, H * W))
|
| 29 |
+
y[:, 0, :, :] = x.flatten(2, 3)
|
| 30 |
+
y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
|
| 31 |
+
y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
|
| 32 |
+
elif scans == 1:
|
| 33 |
+
y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
|
| 34 |
+
elif scans == 2:
|
| 35 |
+
y = x.view(B, 1, C, H * W)
|
| 36 |
+
y = torch.cat([y, y.flip(dims=[-1])], dim=1)
|
| 37 |
+
else:
|
| 38 |
+
B, H, W, C = x.shape
|
| 39 |
+
if scans == 0:
|
| 40 |
+
y = x.new_empty((B, H * W, 4, C))
|
| 41 |
+
y[:, :, 0, :] = x.flatten(1, 2)
|
| 42 |
+
y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
|
| 43 |
+
y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
|
| 44 |
+
elif scans == 1:
|
| 45 |
+
y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
|
| 46 |
+
elif scans == 2:
|
| 47 |
+
y = x.view(B, H * W, 1, C)
|
| 48 |
+
y = torch.cat([y, y.flip(dims=[1])], dim=2)
|
| 49 |
+
|
| 50 |
+
if in_channel_first and (not out_channel_first):
|
| 51 |
+
y = y.permute(0, 3, 1, 2).contiguous()
|
| 52 |
+
elif (not in_channel_first) and out_channel_first:
|
| 53 |
+
y = y.permute(0, 2, 3, 1).contiguous()
|
| 54 |
+
|
| 55 |
+
return y
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2):
|
| 59 |
+
if out_channel_first:
|
| 60 |
+
B, K, D, H, W = y.shape
|
| 61 |
+
y = y.view(B, K, D, -1)
|
| 62 |
+
if scans == 0:
|
| 63 |
+
y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
|
| 64 |
+
y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
|
| 65 |
+
elif scans == 1:
|
| 66 |
+
y = y.sum(1)
|
| 67 |
+
elif scans == 2:
|
| 68 |
+
y = y[:, 0] + y[:, 1].flip(dims=[-1]).view(B, 1, D, -1)
|
| 69 |
+
y = y.sum(1)
|
| 70 |
+
else:
|
| 71 |
+
B, H, W, K, D = y.shape
|
| 72 |
+
y = y.view(B, -1, K, D)
|
| 73 |
+
if scans == 0:
|
| 74 |
+
y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
|
| 75 |
+
y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
|
| 76 |
+
elif scans == 1:
|
| 77 |
+
y = y.sum(2)
|
| 78 |
+
elif scans == 2:
|
| 79 |
+
y = y[:, :, 0] + y[:, :, 1].flip(dims=[1]).view(B, -1, 1, D)
|
| 80 |
+
y = y.sum(2)
|
| 81 |
+
|
| 82 |
+
if in_channel_first and (not out_channel_first):
|
| 83 |
+
y = y.permute(0, 2, 1).contiguous()
|
| 84 |
+
elif (not in_channel_first) and out_channel_first:
|
| 85 |
+
y = y.permute(0, 2, 1).contiguous()
|
| 86 |
+
|
| 87 |
+
return y
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2):
|
| 91 |
+
if in_channel_first:
|
| 92 |
+
B, _, C, H, W = x.shape
|
| 93 |
+
if scans == 0:
|
| 94 |
+
y = torch.stack([
|
| 95 |
+
x[:, 0].flatten(2, 3),
|
| 96 |
+
x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
|
| 97 |
+
torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
|
| 98 |
+
torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
|
| 99 |
+
], dim=1)
|
| 100 |
+
elif scans == 1:
|
| 101 |
+
y = x.flatten(2, 3)
|
| 102 |
+
elif scans == 2:
|
| 103 |
+
y = torch.stack([
|
| 104 |
+
x[:, 0].flatten(2, 3),
|
| 105 |
+
x[:, 1].flatten(2, 3),
|
| 106 |
+
torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
|
| 107 |
+
torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
|
| 108 |
+
], dim=1)
|
| 109 |
+
else:
|
| 110 |
+
B, H, W, _, C = x.shape
|
| 111 |
+
if scans == 0:
|
| 112 |
+
y = torch.stack([
|
| 113 |
+
x[:, :, :, 0].flatten(1, 2),
|
| 114 |
+
x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
|
| 115 |
+
torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
|
| 116 |
+
torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
|
| 117 |
+
], dim=2)
|
| 118 |
+
elif scans == 1:
|
| 119 |
+
y = x.flatten(1, 2)
|
| 120 |
+
elif scans == 2:
|
| 121 |
+
y = torch.stack([
|
| 122 |
+
x[:, 0].flatten(1, 2),
|
| 123 |
+
x[:, 1].flatten(1, 2),
|
| 124 |
+
torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
|
| 125 |
+
torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
|
| 126 |
+
], dim=2)
|
| 127 |
+
|
| 128 |
+
if in_channel_first and (not out_channel_first):
|
| 129 |
+
y = y.permute(0, 3, 1, 2).contiguous()
|
| 130 |
+
elif (not in_channel_first) and out_channel_first:
|
| 131 |
+
y = y.permute(0, 2, 3, 1).contiguous()
|
| 132 |
+
|
| 133 |
+
return y
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2):
|
| 137 |
+
if out_channel_first:
|
| 138 |
+
B, K, D, H, W = y.shape
|
| 139 |
+
y = y.view(B, K, D, -1)
|
| 140 |
+
if scans == 0:
|
| 141 |
+
y = torch.stack([
|
| 142 |
+
y[:, 0],
|
| 143 |
+
y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
|
| 144 |
+
torch.flip(y[:, 2], dims=[-1]),
|
| 145 |
+
torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
|
| 146 |
+
], dim=1)
|
| 147 |
+
elif scans == 1:
|
| 148 |
+
y = y
|
| 149 |
+
elif scans == 2:
|
| 150 |
+
y = torch.stack([
|
| 151 |
+
y[:, 0],
|
| 152 |
+
y[:, 1],
|
| 153 |
+
torch.flip(y[:, 2], dims=[-1]),
|
| 154 |
+
torch.flip(y[:, 3], dims=[-1]),
|
| 155 |
+
], dim=1)
|
| 156 |
+
else:
|
| 157 |
+
B, H, W, _, D = y.shape
|
| 158 |
+
y = y.view(B, -1, 2, D)
|
| 159 |
+
if scans == 0:
|
| 160 |
+
y = torch.stack([
|
| 161 |
+
y[:, :, 0],
|
| 162 |
+
y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
|
| 163 |
+
torch.flip(y[:, :, 2], dims=[1]),
|
| 164 |
+
torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
|
| 165 |
+
], dim=2)
|
| 166 |
+
elif scans == 1:
|
| 167 |
+
y = y
|
| 168 |
+
elif scans == 2:
|
| 169 |
+
y = torch.stack([
|
| 170 |
+
y[:, :, 0],
|
| 171 |
+
y[:, :, 1],
|
| 172 |
+
torch.flip(y[:, :, 2], dims=[1]),
|
| 173 |
+
torch.flip(y[:, :, 3], dims=[1]),
|
| 174 |
+
], dim=2)
|
| 175 |
+
|
| 176 |
+
if out_channel_first and (not in_channel_first):
|
| 177 |
+
y = y.permute(0, 3, 1, 2).contiguous()
|
| 178 |
+
elif (not out_channel_first) and in_channel_first:
|
| 179 |
+
y = y.permute(0, 2, 3, 1).contiguous()
|
| 180 |
+
|
| 181 |
+
return y
|
| 182 |
+
|
| 183 |
+
class CrossScan(torch.nn.Module):
|
| 184 |
+
def __init__(self, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
|
| 185 |
+
super(CrossScan, self).__init__()
|
| 186 |
+
self.in_channel_first = in_channel_first
|
| 187 |
+
self.out_channel_first = out_channel_first
|
| 188 |
+
self.one_by_one = one_by_one
|
| 189 |
+
self.scans = scans
|
| 190 |
+
|
| 191 |
+
def forward(self, x: torch.Tensor):
|
| 192 |
+
if self.one_by_one:
|
| 193 |
+
B, K, C, H, W = x.shape
|
| 194 |
+
if not self.in_channel_first:
|
| 195 |
+
B, H, W, K, C = x.shape
|
| 196 |
+
else:
|
| 197 |
+
B, C, H, W = x.shape
|
| 198 |
+
if not self.in_channel_first:
|
| 199 |
+
B, H, W, C = x.shape
|
| 200 |
+
self.shape = (B, C, H, W)
|
| 201 |
+
|
| 202 |
+
_fn = cross_scan1b1_fwd if self.one_by_one else cross_scan_fwd
|
| 203 |
+
y = _fn(x, self.in_channel_first, self.out_channel_first, self.scans)
|
| 204 |
+
|
| 205 |
+
return y
|
| 206 |
+
|
| 207 |
+
def backward(self, ys: torch.Tensor):
|
| 208 |
+
B, C, H, W = self.shape
|
| 209 |
+
|
| 210 |
+
ys = ys.view(B, -1, C, H, W) if self.out_channel_first else ys.view(B, H, W, -1, C)
|
| 211 |
+
_fn = cross_merge1b1_fwd if self.one_by_one else cross_merge_fwd
|
| 212 |
+
y = _fn(ys, self.in_channel_first, self.out_channel_first, self.scans)
|
| 213 |
+
|
| 214 |
+
if self.one_by_one:
|
| 215 |
+
y = y.view(B, 2, -1, H, W) if self.in_channel_first else y.view(B, H, W, 2, -1)
|
| 216 |
+
else:
|
| 217 |
+
y = y.view(B, -1, H, W) if self.in_channel_first else y.view(B, H, W, -1)
|
| 218 |
+
|
| 219 |
+
return y
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class CrossMerge(torch.nn.Module):
|
| 223 |
+
def __init__(self, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
|
| 224 |
+
super(CrossMerge, self).__init__()
|
| 225 |
+
self.in_channel_first = in_channel_first
|
| 226 |
+
self.out_channel_first = out_channel_first
|
| 227 |
+
self.one_by_one = one_by_one
|
| 228 |
+
self.scans = scans
|
| 229 |
+
|
| 230 |
+
def forward(self, ys: torch.Tensor):
|
| 231 |
+
B, K, C, H, W = ys.shape
|
| 232 |
+
if not self.out_channel_first:
|
| 233 |
+
B, H, W, K, C = ys.shape
|
| 234 |
+
self.shape = (B, C, H, W)
|
| 235 |
+
|
| 236 |
+
_fn = cross_merge1b1_fwd if self.one_by_one else cross_merge_fwd
|
| 237 |
+
y = _fn(ys, self.in_channel_first, self.out_channel_first, self.scans)
|
| 238 |
+
|
| 239 |
+
return y
|
| 240 |
+
|
| 241 |
+
def backward(self, x: torch.Tensor):
|
| 242 |
+
B, C, H, W = self.shape
|
| 243 |
+
|
| 244 |
+
if not self.one_by_one:
|
| 245 |
+
if self.in_channel_first:
|
| 246 |
+
x = x.view(B, C, H, W)
|
| 247 |
+
else:
|
| 248 |
+
x = x.view(B, H, W, C)
|
| 249 |
+
else:
|
| 250 |
+
if self.in_channel_first:
|
| 251 |
+
x = x.view(B, 2, C, H, W)
|
| 252 |
+
else:
|
| 253 |
+
x = x.view(B, H, W, 2, C)
|
| 254 |
+
|
| 255 |
+
_fn = cross_scan1b1_fwd if self.one_by_one else cross_scan_fwd
|
| 256 |
+
x = _fn(x, self.in_channel_first, self.out_channel_first, self.scans)
|
| 257 |
+
x = x.view(B, 2, C, H, W) if self.out_channel_first else x.view(B, H, W, 2, C)
|
| 258 |
+
|
| 259 |
+
return x
|
| 260 |
+
class CrossScanF(torch.autograd.Function):
|
| 261 |
+
@staticmethod
|
| 262 |
+
def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
|
| 263 |
+
# x: (B, C, H, W) | (B, H, W, C) | (B, 2, C, H, W) | (B, H, W, 2, C)
|
| 264 |
+
# y: (B, 2, C, H * W) | (B, H * W, 2, C)
|
| 265 |
+
ctx.in_channel_first = in_channel_first
|
| 266 |
+
ctx.out_channel_first = out_channel_first
|
| 267 |
+
ctx.one_by_one = one_by_one
|
| 268 |
+
ctx.scans = scans
|
| 269 |
+
|
| 270 |
+
if one_by_one:
|
| 271 |
+
B, K, C, H, W = x.shape
|
| 272 |
+
if not in_channel_first:
|
| 273 |
+
B, H, W, K, C = x.shape
|
| 274 |
+
else:
|
| 275 |
+
B, C, H, W = x.shape
|
| 276 |
+
if not in_channel_first:
|
| 277 |
+
B, H, W, C = x.shape
|
| 278 |
+
ctx.shape = (B, C, H, W)
|
| 279 |
+
|
| 280 |
+
_fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
|
| 281 |
+
y = _fn(x, in_channel_first, out_channel_first, scans)
|
| 282 |
+
|
| 283 |
+
return y
|
| 284 |
+
|
| 285 |
+
@staticmethod
|
| 286 |
+
def backward(ctx, ys: torch.Tensor):
|
| 287 |
+
# out: (b, k, d, l)
|
| 288 |
+
in_channel_first = ctx.in_channel_first
|
| 289 |
+
out_channel_first = ctx.out_channel_first
|
| 290 |
+
one_by_one = ctx.one_by_one
|
| 291 |
+
scans = ctx.scans
|
| 292 |
+
B, C, H, W = ctx.shape
|
| 293 |
+
|
| 294 |
+
ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
|
| 295 |
+
_fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
|
| 296 |
+
y = _fn(ys, in_channel_first, out_channel_first, scans)
|
| 297 |
+
|
| 298 |
+
if one_by_one:
|
| 299 |
+
y = y.view(B, 2, -1, H, W) if in_channel_first else y.view(B, H, W, 2, -1)
|
| 300 |
+
else:
|
| 301 |
+
y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)
|
| 302 |
+
|
| 303 |
+
return y, None, None, None, None
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class CrossMergeF(torch.autograd.Function):
|
| 307 |
+
@staticmethod
|
| 308 |
+
def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
|
| 309 |
+
# x: (B, C, H, W) | (B, H, W, C) | (B, 2, C, H, W) | (B, H, W, 2, C)
|
| 310 |
+
# y: (B, 2, C, H * W) | (B, H * W, 2, C)
|
| 311 |
+
ctx.in_channel_first = in_channel_first
|
| 312 |
+
ctx.out_channel_first = out_channel_first
|
| 313 |
+
ctx.one_by_one = one_by_one
|
| 314 |
+
ctx.scans = scans
|
| 315 |
+
|
| 316 |
+
B, K, C, H, W = ys.shape
|
| 317 |
+
if not out_channel_first:
|
| 318 |
+
B, H, W, K, C = ys.shape
|
| 319 |
+
ctx.shape = (B, C, H, W)
|
| 320 |
+
|
| 321 |
+
_fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
|
| 322 |
+
y = _fn(ys, in_channel_first, out_channel_first, scans)
|
| 323 |
+
|
| 324 |
+
return y
|
| 325 |
+
|
| 326 |
+
@staticmethod
|
| 327 |
+
def backward(ctx, x: torch.Tensor):
|
| 328 |
+
# B, D, L = x.shape
|
| 329 |
+
# out: (b, k, d, h, w)
|
| 330 |
+
in_channel_first = ctx.in_channel_first
|
| 331 |
+
out_channel_first = ctx.out_channel_first
|
| 332 |
+
one_by_one = ctx.one_by_one
|
| 333 |
+
scans = ctx.scans
|
| 334 |
+
B, C, H, W = ctx.shape
|
| 335 |
+
|
| 336 |
+
if not one_by_one:
|
| 337 |
+
if in_channel_first:
|
| 338 |
+
x = x.view(B, C, H, W)
|
| 339 |
+
else:
|
| 340 |
+
x = x.view(B, H, W, C)
|
| 341 |
+
else:
|
| 342 |
+
if in_channel_first:
|
| 343 |
+
x = x.view(B, 2, C, H, W)
|
| 344 |
+
else:
|
| 345 |
+
x = x.view(B, H, W, 2, C)
|
| 346 |
+
|
| 347 |
+
_fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
|
| 348 |
+
x = _fn(x, in_channel_first, out_channel_first, scans)
|
| 349 |
+
x = x.view(B, 2, C, H, W) if out_channel_first else x.view(B, H, W, 2, C)
|
| 350 |
+
|
| 351 |
+
return x, None, None, None, None
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# triton implements ========================================
|
| 355 |
+
|
| 356 |
+
@triton.jit
|
| 357 |
+
def triton_cross_scan_flex_k2(
|
| 358 |
+
x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
|
| 359 |
+
y, # (B, 4, C, H, W) | (B, H, W, 4, C)
|
| 360 |
+
x_layout: tl.constexpr,
|
| 361 |
+
y_layout: tl.constexpr,
|
| 362 |
+
operation: tl.constexpr,
|
| 363 |
+
onebyone: tl.constexpr,
|
| 364 |
+
scans: tl.constexpr,
|
| 365 |
+
BC: tl.constexpr,
|
| 366 |
+
BH: tl.constexpr,
|
| 367 |
+
BW: tl.constexpr,
|
| 368 |
+
DC: tl.constexpr,
|
| 369 |
+
DH: tl.constexpr,
|
| 370 |
+
DW: tl.constexpr,
|
| 371 |
+
NH: tl.constexpr,
|
| 372 |
+
NW: tl.constexpr,
|
| 373 |
+
):
|
| 374 |
+
# x_layout = 0
|
| 375 |
+
# y_layout = 1 # 0 BCHW, 1 BHWC
|
| 376 |
+
# operation = 0 # 0 scan, 1 merge
|
| 377 |
+
# onebyone = 0 # 0 false, 1 true
|
| 378 |
+
# scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional
|
| 379 |
+
|
| 380 |
+
i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
| 381 |
+
i_h, i_w = (i_hw // NW), (i_hw % NW)
|
| 382 |
+
_mask_h = (i_h * BH + tl.arange(0, BH)) < DH
|
| 383 |
+
_mask_w = (i_w * BW + tl.arange(0, BW)) < DW
|
| 384 |
+
_mask_hw = _mask_h[:, None] & _mask_w[None, :]
|
| 385 |
+
_for_C = min(DC - i_c * BC, BC)
|
| 386 |
+
|
| 387 |
+
HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
|
| 388 |
+
# HWRoute1 = i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None] # trans
|
| 389 |
+
HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip
|
| 390 |
+
# HWRoute3 = (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip
|
| 391 |
+
|
| 392 |
+
if scans == 1:
|
| 393 |
+
HWRoute2 = HWRoute0
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
_tmp1 = DC * DH * DW
|
| 397 |
+
|
| 398 |
+
y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
|
| 399 |
+
if y_layout == 0:
|
| 400 |
+
p_y1 = y_ptr_base + HWRoute0
|
| 401 |
+
# p_y2 = y_ptr_base + _tmp1 + HWRoute1
|
| 402 |
+
p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
|
| 403 |
+
# p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
|
| 404 |
+
else:
|
| 405 |
+
p_y1 = y_ptr_base + HWRoute0 * 4 * DC
|
| 406 |
+
# p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
|
| 407 |
+
p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
|
| 408 |
+
# p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC
|
| 409 |
+
|
| 410 |
+
if onebyone == 0:
|
| 411 |
+
x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
|
| 412 |
+
if x_layout == 0:
|
| 413 |
+
p_x = x_ptr_base + HWRoute0
|
| 414 |
+
else:
|
| 415 |
+
p_x = x_ptr_base + HWRoute0 * DC
|
| 416 |
+
|
| 417 |
+
if operation == 0:
|
| 418 |
+
for idxc in range(_for_C):
|
| 419 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 420 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 421 |
+
_x = tl.load(p_x + _idx_x, mask=_mask_hw)
|
| 422 |
+
tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
|
| 423 |
+
# tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
|
| 424 |
+
tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
|
| 425 |
+
# tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
|
| 426 |
+
elif operation == 1:
|
| 427 |
+
for idxc in range(_for_C):
|
| 428 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 429 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 430 |
+
_y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
|
| 431 |
+
# _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
|
| 432 |
+
_y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
|
| 433 |
+
# _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
|
| 434 |
+
# tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
|
| 435 |
+
tl.store(p_x + _idx_x, _y1 + _y3, mask=_mask_hw)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
else:
|
| 439 |
+
x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
|
| 440 |
+
if x_layout == 0:
|
| 441 |
+
p_x1 = x_ptr_base + HWRoute0
|
| 442 |
+
p_x2 = p_x1 + _tmp1
|
| 443 |
+
p_x3 = p_x2 + _tmp1
|
| 444 |
+
p_x4 = p_x3 + _tmp1
|
| 445 |
+
else:
|
| 446 |
+
p_x1 = x_ptr_base + HWRoute0 * 4 * DC
|
| 447 |
+
p_x2 = p_x1 + DC
|
| 448 |
+
p_x3 = p_x2 + DC
|
| 449 |
+
p_x4 = p_x3 + DC
|
| 450 |
+
|
| 451 |
+
if operation == 0:
|
| 452 |
+
for idxc in range(_for_C):
|
| 453 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 454 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 455 |
+
tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
|
| 456 |
+
# tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
|
| 457 |
+
tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
|
| 458 |
+
# tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
|
| 459 |
+
else:
|
| 460 |
+
for idxc in range(_for_C):
|
| 461 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 462 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 463 |
+
tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
|
| 464 |
+
# tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
|
| 465 |
+
tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
|
| 466 |
+
# tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)
|
| 467 |
+
|
| 468 |
+
@triton.jit
|
| 469 |
+
def triton_cross_scan_flex_k2(
|
| 470 |
+
x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
|
| 471 |
+
y, # (B, 4, C, H, W) | (B, H, W, 4, C)
|
| 472 |
+
x_layout: tl.constexpr,
|
| 473 |
+
y_layout: tl.constexpr,
|
| 474 |
+
operation: tl.constexpr,
|
| 475 |
+
onebyone: tl.constexpr,
|
| 476 |
+
scans: tl.constexpr,
|
| 477 |
+
BC: tl.constexpr,
|
| 478 |
+
BH: tl.constexpr,
|
| 479 |
+
BW: tl.constexpr,
|
| 480 |
+
DC: tl.constexpr,
|
| 481 |
+
DH: tl.constexpr,
|
| 482 |
+
DW: tl.constexpr,
|
| 483 |
+
NH: tl.constexpr,
|
| 484 |
+
NW: tl.constexpr,
|
| 485 |
+
):
|
| 486 |
+
i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
| 487 |
+
i_h, i_w = (i_hw // NW), (i_hw % NW)
|
| 488 |
+
_mask_h = (i_h * BH + tl.arange(0, BH)) < DH
|
| 489 |
+
_mask_w = (i_w * BW + tl.arange(0, BW)) < DW
|
| 490 |
+
_mask_hw = _mask_h[:, None] & _mask_w[None, :]
|
| 491 |
+
_for_C = min(DC - i_c * BC, BC)
|
| 492 |
+
|
| 493 |
+
HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
|
| 494 |
+
HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip
|
| 495 |
+
|
| 496 |
+
if scans == 1:
|
| 497 |
+
HWRoute2 = HWRoute0
|
| 498 |
+
|
| 499 |
+
_tmp1 = DC * DH * DW
|
| 500 |
+
|
| 501 |
+
y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
|
| 502 |
+
if y_layout == 0:
|
| 503 |
+
p_y1 = y_ptr_base + HWRoute0
|
| 504 |
+
p_y2 = y_ptr_base + 2 * _tmp1 + HWRoute2
|
| 505 |
+
else:
|
| 506 |
+
p_y1 = y_ptr_base + HWRoute0 * 4 * DC
|
| 507 |
+
p_y2 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
|
| 508 |
+
|
| 509 |
+
if onebyone == 0:
|
| 510 |
+
x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
|
| 511 |
+
if x_layout == 0:
|
| 512 |
+
p_x = x_ptr_base + HWRoute0
|
| 513 |
+
else:
|
| 514 |
+
p_x = x_ptr_base + HWRoute0 * DC
|
| 515 |
+
|
| 516 |
+
if operation == 0:
|
| 517 |
+
for idxc in range(_for_C):
|
| 518 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 519 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 520 |
+
_x = tl.load(p_x + _idx_x, mask=_mask_hw)
|
| 521 |
+
tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
|
| 522 |
+
tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
|
| 523 |
+
elif operation == 1:
|
| 524 |
+
for idxc in range(_for_C):
|
| 525 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 526 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 527 |
+
_y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
|
| 528 |
+
_y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
|
| 529 |
+
tl.store(p_x + _idx_x, _y1 + _y2, mask=_mask_hw)
|
| 530 |
+
|
| 531 |
+
else:
|
| 532 |
+
x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
|
| 533 |
+
if x_layout == 0:
|
| 534 |
+
p_x1 = x_ptr_base + HWRoute0
|
| 535 |
+
p_x2 = p_x1 + 2 * _tmp1
|
| 536 |
+
else:
|
| 537 |
+
p_x1 = x_ptr_base + HWRoute0 * 4 * DC
|
| 538 |
+
p_x2 = p_x1 + 2 * DC
|
| 539 |
+
|
| 540 |
+
if operation == 0:
|
| 541 |
+
for idxc in range(_for_C):
|
| 542 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 543 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 544 |
+
tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
|
| 545 |
+
tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
|
| 546 |
+
else:
|
| 547 |
+
for idxc in range(_for_C):
|
| 548 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 549 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 550 |
+
tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
|
| 551 |
+
tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
|
| 552 |
+
|
| 553 |
+
@triton.jit
|
| 554 |
+
def triton_cross_scan_flex_k2(
|
| 555 |
+
x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
|
| 556 |
+
y, # (B, 4, C, H, W) | (B, H, W, 4, C)
|
| 557 |
+
x_layout: tl.constexpr,
|
| 558 |
+
y_layout: tl.constexpr,
|
| 559 |
+
operation: tl.constexpr,
|
| 560 |
+
onebyone: tl.constexpr,
|
| 561 |
+
scans: tl.constexpr,
|
| 562 |
+
BC: tl.constexpr,
|
| 563 |
+
BH: tl.constexpr,
|
| 564 |
+
BW: tl.constexpr,
|
| 565 |
+
DC: tl.constexpr,
|
| 566 |
+
DH: tl.constexpr,
|
| 567 |
+
DW: tl.constexpr,
|
| 568 |
+
NH: tl.constexpr,
|
| 569 |
+
NW: tl.constexpr,
|
| 570 |
+
):
|
| 571 |
+
i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
| 572 |
+
i_h, i_w = (i_hw // NW), (i_hw % NW)
|
| 573 |
+
_mask_h = (i_h * BH + tl.arange(0, BH)) < DH
|
| 574 |
+
_mask_w = (i_w * BW + tl.arange(0, BW)) < DW
|
| 575 |
+
_mask_hw = _mask_h[:, None] & _mask_w[None, :]
|
| 576 |
+
_for_C = min(DC - i_c * BC, BC)
|
| 577 |
+
|
| 578 |
+
HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
|
| 579 |
+
HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip
|
| 580 |
+
|
| 581 |
+
if scans == 1:
|
| 582 |
+
HWRoute2 = HWRoute0
|
| 583 |
+
|
| 584 |
+
_tmp1 = DC * DH * DW
|
| 585 |
+
|
| 586 |
+
y_ptr_base = y + i_b * 2 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
|
| 587 |
+
if y_layout == 0:
|
| 588 |
+
p_y1 = y_ptr_base + HWRoute0
|
| 589 |
+
p_y2 = y_ptr_base + 1 * _tmp1 + HWRoute2
|
| 590 |
+
else:
|
| 591 |
+
p_y1 = y_ptr_base + HWRoute0 * 4 * DC
|
| 592 |
+
p_y2 = y_ptr_base + 1 * DC + HWRoute2 * 4 * DC
|
| 593 |
+
|
| 594 |
+
if onebyone == 0:
|
| 595 |
+
x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
|
| 596 |
+
if x_layout == 0:
|
| 597 |
+
p_x = x_ptr_base + HWRoute0
|
| 598 |
+
else:
|
| 599 |
+
p_x = x_ptr_base + HWRoute0 * DC
|
| 600 |
+
|
| 601 |
+
if operation == 0:
|
| 602 |
+
for idxc in range(_for_C):
|
| 603 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 604 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 605 |
+
_x = tl.load(p_x + _idx_x, mask=_mask_hw)
|
| 606 |
+
tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
|
| 607 |
+
tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
|
| 608 |
+
elif operation == 1:
|
| 609 |
+
for idxc in range(_for_C):
|
| 610 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 611 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 612 |
+
_y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
|
| 613 |
+
_y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
|
| 614 |
+
tl.store(p_x + _idx_x, _y1 + _y2, mask=_mask_hw)
|
| 615 |
+
|
| 616 |
+
else:
|
| 617 |
+
x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
|
| 618 |
+
if x_layout == 0:
|
| 619 |
+
p_x1 = x_ptr_base + HWRoute0
|
| 620 |
+
p_x2 = p_x1 + 2 * _tmp1
|
| 621 |
+
else:
|
| 622 |
+
p_x1 = x_ptr_base + HWRoute0 * 4 * DC
|
| 623 |
+
p_x2 = p_x1 + 2 * DC
|
| 624 |
+
|
| 625 |
+
if operation == 0:
|
| 626 |
+
for idxc in range(_for_C):
|
| 627 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 628 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 629 |
+
_x1 = tl.load(p_x1 + _idx_x, mask=_mask_hw)
|
| 630 |
+
_x2 = tl.load(p_x2 + _idx_x, mask=_mask_hw)
|
| 631 |
+
tl.store(p_y1 + _idx_y, _x1, mask=_mask_hw)
|
| 632 |
+
tl.store(p_y2 + _idx_y, _x2, mask=_mask_hw)
|
| 633 |
+
else:
|
| 634 |
+
for idxc in range(_for_C):
|
| 635 |
+
_idx_x = idxc * DH * DW if x_layout == 0 else idxc
|
| 636 |
+
_idx_y = idxc * DH * DW if y_layout == 0 else idxc
|
| 637 |
+
_y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
|
| 638 |
+
_y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
|
| 639 |
+
tl.store(p_x1 + _idx_x, _y1, mask=_mask_hw)
|
| 640 |
+
tl.store(p_x2 + _idx_x, _y2, mask=_mask_hw)
|
| 641 |
+
|
| 642 |
+
class CrossScanTritonFk2(torch.autograd.Function):
|
| 643 |
+
@staticmethod
|
| 644 |
+
def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
|
| 645 |
+
if one_by_one:
|
| 646 |
+
if in_channel_first:
|
| 647 |
+
B, _, C, H, W = x.shape
|
| 648 |
+
else:
|
| 649 |
+
B, H, W, _, C = x.shape
|
| 650 |
+
else:
|
| 651 |
+
if in_channel_first:
|
| 652 |
+
B, C, H, W = x.shape
|
| 653 |
+
else:
|
| 654 |
+
B, H, W, C = x.shape
|
| 655 |
+
B, C, H, W = int(B), int(C), int(H), int(W)
|
| 656 |
+
BC, BH, BW = 1, 32, 32
|
| 657 |
+
NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
|
| 658 |
+
|
| 659 |
+
ctx.in_channel_first = in_channel_first
|
| 660 |
+
ctx.out_channel_first = out_channel_first
|
| 661 |
+
ctx.one_by_one = one_by_one
|
| 662 |
+
ctx.scans = scans
|
| 663 |
+
ctx.shape = (B, C, H, W)
|
| 664 |
+
ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
|
| 665 |
+
|
| 666 |
+
y = x.new_empty((B, 2, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 2, C))
|
| 667 |
+
triton_cross_scan_flex_k2[(NH * NW, NC, B)](
|
| 668 |
+
x.contiguous(), y,
|
| 669 |
+
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
|
| 670 |
+
BC, BH, BW, C, H, W, NH, NW
|
| 671 |
+
)
|
| 672 |
+
return y
|
| 673 |
+
|
| 674 |
+
@staticmethod
|
| 675 |
+
def backward(ctx, y: torch.Tensor):
|
| 676 |
+
in_channel_first = ctx.in_channel_first
|
| 677 |
+
out_channel_first = ctx.out_channel_first
|
| 678 |
+
one_by_one = ctx.one_by_one
|
| 679 |
+
scans = ctx.scans
|
| 680 |
+
B, C, H, W = ctx.shape
|
| 681 |
+
BC, BH, BW, NC, NH, NW = ctx.triton_shape
|
| 682 |
+
if one_by_one:
|
| 683 |
+
x = y.new_empty((B, 2, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 2, C))
|
| 684 |
+
else:
|
| 685 |
+
x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))
|
| 686 |
+
|
| 687 |
+
triton_cross_scan_flex_k2[(NH * NW, NC, B)](
|
| 688 |
+
x, y.contiguous(),
|
| 689 |
+
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
|
| 690 |
+
BC, BH, BW, C, H, W, NH, NW
|
| 691 |
+
)
|
| 692 |
+
return x, None, None, None, None
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
class CrossMergeTritonFk2(torch.autograd.Function):
|
| 696 |
+
@staticmethod
|
| 697 |
+
def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2):
|
| 698 |
+
if out_channel_first:
|
| 699 |
+
B, _, C, H, W = y.shape
|
| 700 |
+
else:
|
| 701 |
+
B, H, W, _, C = y.shape
|
| 702 |
+
B, C, H, W = int(B), int(C), int(H), int(W)
|
| 703 |
+
BC, BH, BW = 1, 32, 32
|
| 704 |
+
NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
|
| 705 |
+
ctx.in_channel_first = in_channel_first
|
| 706 |
+
ctx.out_channel_first = out_channel_first
|
| 707 |
+
ctx.one_by_one = one_by_one
|
| 708 |
+
ctx.scans = scans
|
| 709 |
+
ctx.shape = (B, C, H, W)
|
| 710 |
+
ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
|
| 711 |
+
if one_by_one:
|
| 712 |
+
x = y.new_empty((B, 2, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 2, C))
|
| 713 |
+
else:
|
| 714 |
+
x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
|
| 715 |
+
triton_cross_scan_flex_k2[(NH * NW, NC, B)](
|
| 716 |
+
x, y.contiguous(),
|
| 717 |
+
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
|
| 718 |
+
BC, BH, BW, C, H, W, NH, NW
|
| 719 |
+
)
|
| 720 |
+
return x
|
| 721 |
+
|
| 722 |
+
@staticmethod
|
| 723 |
+
def backward(ctx, x: torch.Tensor):
|
| 724 |
+
in_channel_first = ctx.in_channel_first
|
| 725 |
+
out_channel_first = ctx.out_channel_first
|
| 726 |
+
one_by_one = ctx.one_by_one
|
| 727 |
+
scans = ctx.scans
|
| 728 |
+
B, C, H, W = ctx.shape
|
| 729 |
+
BC, BH, BW, NC, NH, NW = ctx.triton_shape
|
| 730 |
+
y = x.new_empty((B, 2, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 2, C))
|
| 731 |
+
triton_cross_scan_flex_k2[(NH * NW, NC, B)](
|
| 732 |
+
x.contiguous(), y,
|
| 733 |
+
(0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
|
| 734 |
+
BC, BH, BW, C, H, W, NH, NW
|
| 735 |
+
)
|
| 736 |
+
return y, None, None, None, None
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
|
| 740 |
+
def cross_scan_fn_k2(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False):
|
| 741 |
+
# x: (B, C, H, W) | (B, H, W, C) | (B, 2, C, H, W) | (B, H, W, 2, C)
|
| 742 |
+
# y: (B, 2, C, L) | (B, L, 2, C)
|
| 743 |
+
# scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
|
| 744 |
+
CSF = CrossScanTritonFk2 if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
|
| 745 |
+
return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
|
| 746 |
+
|
| 747 |
+
# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
|
| 748 |
+
def cross_merge_fn_k2(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False):
|
| 749 |
+
# y: (B, 2, C, L) | (B, L, 2, C)
|
| 750 |
+
# x: (B, C, H * W) | (B, H * W, C) | (B, 2, C, H * W) | (B, H * W, 2, C)
|
| 751 |
+
# scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
|
| 752 |
+
CMF = CrossMergeTritonFk2 if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
|
| 753 |
+
return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
|
| 754 |
+
|
| 755 |
+
def cross_scan_fn_k2_torch(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False):
|
| 756 |
+
cross_scan = CrossScan(in_channel_first, out_channel_first, one_by_one, scans)
|
| 757 |
+
return cross_scan(x)
|
| 758 |
+
|
| 759 |
+
def cross_merge_fn_k2_torch(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False):
|
| 760 |
+
cross_merge = CrossMerge(in_channel_first, out_channel_first, one_by_one, scans)
|
| 761 |
+
return cross_merge(y)
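# --- usage sketch (illustrative; not part of the original upstream file) ---
# A minimal round trip through the k2 scan/merge pair, assuming a CUDA device and
# the default channel-first layouts documented above. Shapes follow the comments:
# scan maps (B, C, H, W) -> (B, 2, C, H*W); merge expects (B, 2, C, H, W) back.
def _example_cross_scan_merge_k2():
    B, C, H, W = 2, 8, 16, 16
    x = torch.randn(B, C, H, W, device="cuda")
    ys = cross_scan_fn_k2(x, in_channel_first=True, out_channel_first=True, scans=2)
    out = cross_merge_fn_k2(ys.view(B, 2, C, H, W), in_channel_first=True, out_channel_first=True, scans=2)
    return ys.shape, out.shape  # (B, 2, C, H*W), (B, C, H*W)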
|
| 762 |
+
|
| 763 |
+
# checks =================================================================
|
| 764 |
+
|
| 765 |
+
class CHECK:
|
| 766 |
+
def check_csm_triton():
|
| 767 |
+
B, C, H, W = 2, 192, 56, 57
|
| 768 |
+
dtype=torch.float16
|
| 769 |
+
dtype=torch.float32
|
| 770 |
+
x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
|
| 771 |
+
y = torch.randn((B, 2, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
|
| 772 |
+
x1 = x.clone().detach().requires_grad_(True)
|
| 773 |
+
y1 = y.clone().detach().requires_grad_(True)
|
| 774 |
+
|
| 775 |
+
def cross_scan(x: torch.Tensor):
|
| 776 |
+
B, C, H, W = x.shape
|
| 777 |
+
L = H * W
|
| 778 |
+
xs = torch.stack([
|
| 779 |
+
x.view(B, C, L),
|
| 780 |
+
torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
|
| 781 |
+
torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
|
| 782 |
+
torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
|
| 783 |
+
], dim=1).view(B, 4, C, L)
|
| 784 |
+
return xs
|
| 785 |
+
|
| 786 |
+
def cross_merge(out_y: torch.Tensor):
|
| 787 |
+
B, K, D, H, W = out_y.shape
|
| 788 |
+
L = H * W
|
| 789 |
+
out_y = out_y.view(B, K, D, L)
|
| 790 |
+
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
|
| 791 |
+
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
|
| 792 |
+
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
|
| 793 |
+
y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
|
| 794 |
+
return y
|
| 795 |
+
|
| 796 |
+
def cross_scan_1b1(x: torch.Tensor):
|
| 797 |
+
B, K, C, H, W = x.shape
|
| 798 |
+
L = H * W
|
| 799 |
+
xs = torch.stack([
|
| 800 |
+
x[:, 0].view(B, C, L),
|
| 801 |
+
torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
|
| 802 |
+
torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
|
| 803 |
+
torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
|
| 804 |
+
], dim=1).view(B, 2, C, L)
|
| 805 |
+
return xs
|
| 806 |
+
|
| 807 |
+
def unidi_scan(x):
|
| 808 |
+
B, C, H, W = x.shape
|
| 809 |
+
x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
|
| 810 |
+
return x
|
| 811 |
+
|
| 812 |
+
def unidi_merge(ys):
|
| 813 |
+
B, K, C, H, W = ys.shape
|
| 814 |
+
return ys.view(B, 4, -1, H * W).sum(1)
|
| 815 |
+
|
| 816 |
+
def bidi_scan(x):
|
| 817 |
+
B, C, H, W = x.shape
|
| 818 |
+
x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
|
| 819 |
+
x = torch.cat([x, x.flip(dims=[-1])], dim=1)
|
| 820 |
+
return x
|
| 821 |
+
|
| 822 |
+
def bidi_merge(ys):
|
| 823 |
+
B, K, D, H, W = ys.shape
|
| 824 |
+
ys = ys.view(B, K, D, -1)
|
| 825 |
+
ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
|
| 826 |
+
return ys.contiguous().sum(1)
|
| 827 |
+
|
| 828 |
+
if True:
|
| 829 |
+
# res0 = triton.testing.do_bench(lambda :cross_scan(x))
|
| 830 |
+
res1 = triton.testing.do_bench(lambda :cross_scan_fn_k2(x, True, True, False))
|
| 831 |
+
# res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x))
|
| 832 |
+
# res3 = triton.testing.do_bench(lambda :cross_merge(y))
|
| 833 |
+
res4 = triton.testing.do_bench(lambda :cross_merge_fn_k2(y, True, True, False))
|
| 834 |
+
# res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y))
|
| 835 |
+
# print(res0, res1, res2, res3, res4, res5)
|
| 836 |
+
print(res1, res4)
|
| 837 |
+
res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward())
|
| 838 |
+
res1 = triton.testing.do_bench(lambda :cross_scan_fn_k2(x, True, True, False).sum().backward())
|
| 839 |
+
# res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x).sum().backward())
|
| 840 |
+
res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward())
|
| 841 |
+
res4 = triton.testing.do_bench(lambda :cross_merge_fn_k2(y, True, True, False).sum().backward())
|
| 842 |
+
# res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y).sum().backward())
|
| 843 |
+
# print(res0, res1, res2, res3, res4, res5)
|
| 844 |
+
print(res0, res1, res3, res4)
|
| 845 |
+
|
| 846 |
+
print("test cross scan")
|
| 847 |
+
for (cs0, cm0, cs1, cm1) in [
|
| 848 |
+
# channel_first -> channel_first
|
| 849 |
+
(cross_scan, cross_merge, cross_scan_fn_k2, cross_merge_fn_k2),
|
| 850 |
+
(unidi_scan, unidi_merge, lambda x: cross_scan_fn_k2(x, scans=1), lambda x: cross_merge_fn_k2(x, scans=1)),
|
| 851 |
+
(bidi_scan, bidi_merge, lambda x: cross_scan_fn_k2(x, scans=2), lambda x: cross_merge_fn_k2(x, scans=2)),
|
| 852 |
+
|
| 853 |
+
# flex: BLC->BCL; BCL->BLC; BLC->BLC;
|
| 854 |
+
(cross_scan, cross_merge, lambda x: cross_scan_fn_k2(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn_k2(x, in_channel_first=False).permute(0, 2, 1)),
|
| 855 |
+
(cross_scan, cross_merge, lambda x: cross_scan_fn_k2(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn_k2(x.permute(0, 3, 4, 1, 2), out_channel_first=False)),
|
| 856 |
+
(cross_scan, cross_merge, lambda x: cross_scan_fn_k2(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn_k2(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)),
|
| 857 |
+
|
| 858 |
+
# previous
|
| 859 |
+
# (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)),
|
| 860 |
+
# (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)),
|
| 861 |
+
# (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)),
|
| 862 |
+
]:
|
| 863 |
+
x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
|
| 864 |
+
o0 = cs0(x)
|
| 865 |
+
o1 = cs1(x1)
|
| 866 |
+
o0.backward(y.view(B, 2, C, H * W))
|
| 867 |
+
o1.backward(y.view(B, 2, C, H * W))
|
| 868 |
+
print((o0 - o1).abs().max())
|
| 869 |
+
print((x.grad - x1.grad).abs().max())
|
| 870 |
+
o0 = cm0(y)
|
| 871 |
+
o1 = cm1(y1)
|
| 872 |
+
o0.backward(x.view(B, C, H * W))
|
| 873 |
+
o1.backward(x.view(B, C, H * W))
|
| 874 |
+
print((o0 - o1).abs().max())
|
| 875 |
+
print((y.grad - y1.grad).abs().max())
|
| 876 |
+
x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
|
| 877 |
+
print("===============", flush=True)
|
| 878 |
+
|
| 879 |
+
print("test cross scan one by one")
|
| 880 |
+
for (cs0, cs1) in [
|
| 881 |
+
(cross_scan_1b1, lambda x: cross_scan_fn_k2(x, one_by_one=True)),
|
| 882 |
+
# (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)),
|
| 883 |
+
]:
|
| 884 |
+
o0 = cs0(y)
|
| 885 |
+
o1 = cs1(y1)
|
| 886 |
+
o0.backward(y.view(B, 2, C, H * W))
|
| 887 |
+
o1.backward(y.view(B, 2, C, H * W))
|
| 888 |
+
print((o0 - o1).abs().max())
|
| 889 |
+
print((y.grad - y1.grad).abs().max())
|
| 890 |
+
x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
|
| 891 |
+
print("===============", flush=True)
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
if __name__ == "__main__":
|
| 895 |
+
CHECK.check_csm_triton()
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
|
rscd/models/backbones/lib_mamba/csms6s.py
ADDED
|
@@ -0,0 +1,266 @@
|
| 1 |
+
import time
|
| 2 |
+
import torch
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
WITH_SELECTIVESCAN_OFLEX = True
|
| 7 |
+
WITH_SELECTIVESCAN_CORE = False
|
| 8 |
+
WITH_SELECTIVESCAN_MAMBA = True
|
| 9 |
+
try:
|
| 10 |
+
import selective_scan_cuda_oflex
|
| 11 |
+
except ImportError:
|
| 12 |
+
WITH_SELECTIVESCAN_OFLEX = False
|
| 13 |
+
warnings.warn("Can not import selective_scan_cuda_oflex. This affects speed.")
|
| 14 |
+
print("Can not import selective_scan_cuda_oflex. This affects speed.", flush=True)
|
| 15 |
+
try:
|
| 16 |
+
import selective_scan_cuda_core
|
| 17 |
+
except ImportError:
|
| 18 |
+
WITH_SELECTIVESCAN_CORE = False
|
| 19 |
+
try:
|
| 20 |
+
import selective_scan_cuda
|
| 21 |
+
except ImportError:
|
| 22 |
+
WITH_SELECTIVESCAN_MAMBA = False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def selective_scan_torch(
|
| 26 |
+
u: torch.Tensor, # (B, K * C, L)
|
| 27 |
+
delta: torch.Tensor, # (B, K * C, L)
|
| 28 |
+
A: torch.Tensor, # (K * C, N)
|
| 29 |
+
B: torch.Tensor, # (B, K, N, L)
|
| 30 |
+
C: torch.Tensor, # (B, K, N, L)
|
| 31 |
+
D: torch.Tensor = None, # (K * C)
|
| 32 |
+
delta_bias: torch.Tensor = None, # (K * C)
|
| 33 |
+
delta_softplus=True,
|
| 34 |
+
oflex=True,
|
| 35 |
+
*args,
|
| 36 |
+
**kwargs
|
| 37 |
+
):
|
| 38 |
+
dtype_in = u.dtype
|
| 39 |
+
Batch, K, N, L = B.shape
|
| 40 |
+
KCdim = u.shape[1]
|
| 41 |
+
Cdim = int(KCdim / K)
|
| 42 |
+
assert u.shape == (Batch, KCdim, L)
|
| 43 |
+
assert delta.shape == (Batch, KCdim, L)
|
| 44 |
+
assert A.shape == (KCdim, N)
|
| 45 |
+
assert C.shape == B.shape
|
| 46 |
+
|
| 47 |
+
if delta_bias is not None:
|
| 48 |
+
delta = delta + delta_bias[..., None]
|
| 49 |
+
if delta_softplus:
|
| 50 |
+
delta = torch.nn.functional.softplus(delta)
|
| 51 |
+
|
| 52 |
+
u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
|
| 53 |
+
B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
|
| 54 |
+
C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
|
| 55 |
+
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
| 56 |
+
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
|
| 57 |
+
|
| 58 |
+
if True:
|
| 59 |
+
x = A.new_zeros((Batch, KCdim, N))
|
| 60 |
+
ys = []
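# Sequential SSM recurrence over the length dimension L (reference path):
#   x_i = exp(delta_i * A) * x_{i-1} + delta_i * B_i * u_i,   y_i = <C_i, x_i>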
|
| 61 |
+
for i in range(L):
|
| 62 |
+
x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
|
| 63 |
+
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
|
| 64 |
+
ys.append(y)
|
| 65 |
+
y = torch.stack(ys, dim=2) # (B, C, L)
|
| 66 |
+
|
| 67 |
+
out = y if D is None else y + u * D.unsqueeze(-1)
|
| 68 |
+
return out if oflex else out.to(dtype=dtype_in)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class SelectiveScanCuda(torch.autograd.Function):
|
| 72 |
+
@staticmethod
|
| 73 |
+
@torch.cuda.amp.custom_fwd
|
| 74 |
+
def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
|
| 75 |
+
ctx.delta_softplus = delta_softplus
|
| 76 |
+
backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend
|
| 77 |
+
backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend
|
| 78 |
+
backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend
|
| 79 |
+
ctx.backend = backend
|
| 80 |
+
if backend == "oflex":
|
| 81 |
+
out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
|
| 82 |
+
elif backend == "core":
|
| 83 |
+
out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
|
| 84 |
+
elif backend == "mamba":
|
| 85 |
+
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
|
| 86 |
+
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
| 87 |
+
return out
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
@torch.cuda.amp.custom_bwd
|
| 91 |
+
def backward(ctx, dout, *args):
|
| 92 |
+
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
| 93 |
+
backend = ctx.backend
|
| 94 |
+
if dout.stride(-1) != 1:
|
| 95 |
+
dout = dout.contiguous()
|
| 96 |
+
if backend == "oflex":
|
| 97 |
+
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
|
| 98 |
+
u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
|
| 99 |
+
)
|
| 100 |
+
elif backend == "core":
|
| 101 |
+
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
|
| 102 |
+
u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
|
| 103 |
+
)
|
| 104 |
+
elif backend == "mamba":
|
| 105 |
+
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
| 106 |
+
u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
|
| 107 |
+
False
|
| 108 |
+
)
|
| 109 |
+
return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def selective_scan_fn(
|
| 113 |
+
u: torch.Tensor, # (B, K * C, L)
|
| 114 |
+
delta: torch.Tensor, # (B, K * C, L)
|
| 115 |
+
A: torch.Tensor, # (K * C, N)
|
| 116 |
+
B: torch.Tensor, # (B, K, N, L)
|
| 117 |
+
C: torch.Tensor, # (B, K, N, L)
|
| 118 |
+
D: torch.Tensor = None, # (K * C)
|
| 119 |
+
delta_bias: torch.Tensor = None, # (K * C)
|
| 120 |
+
delta_softplus=True,
|
| 121 |
+
oflex=True,
|
| 122 |
+
backend=None,
|
| 123 |
+
):
|
| 124 |
+
WITH_CUDA = (WITH_SELECTIVESCAN_OFLEX or WITH_SELECTIVESCAN_CORE or WITH_SELECTIVESCAN_MAMBA)
|
| 125 |
+
fn = selective_scan_torch if backend == "torch" or (not WITH_CUDA) else SelectiveScanCuda.apply
|
| 126 |
+
return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
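# --- usage sketch (illustrative; not part of the original upstream file) ---
# Runs the pure-PyTorch reference path, so no CUDA extension is required.
# Tensor shapes follow the docstring comments above (K groups, Cdim dims per group).
def _example_selective_scan():
    Batch, K, Cdim, N, L = 2, 2, 4, 8, 32
    u = torch.randn(Batch, K * Cdim, L)
    delta = 0.5 * torch.rand(Batch, K * Cdim, L)
    A = -0.5 * torch.rand(K * Cdim, N)
    Bs = torch.randn(Batch, K, N, L)
    Cs = torch.randn(Batch, K, N, L)
    D = torch.randn(K * Cdim)
    delta_bias = 0.5 * torch.rand(K * Cdim)
    out = selective_scan_fn(u, delta, A, Bs, Cs, D, delta_bias, delta_softplus=True, backend="torch")
    return out.shape  # torch.Size([Batch, K * Cdim, L])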
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# fvcore flops =======================================
|
| 130 |
+
def print_jit_input_names(inputs):
|
| 131 |
+
print("input params: ", end=" ", flush=True)
|
| 132 |
+
try:
|
| 133 |
+
for i in range(10):
|
| 134 |
+
print(inputs[i].debugName(), end=" ", flush=True)
|
| 135 |
+
except Exception as e:
|
| 136 |
+
pass
|
| 137 |
+
print("", flush=True)
|
| 138 |
+
|
| 139 |
+
def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
|
| 140 |
+
"""
|
| 141 |
+
u: r(B D L)
|
| 142 |
+
delta: r(B D L)
|
| 143 |
+
A: r(D N)
|
| 144 |
+
B: r(B N L)
|
| 145 |
+
C: r(B N L)
|
| 146 |
+
D: r(D)
|
| 147 |
+
z: r(B D L)
|
| 148 |
+
delta_bias: r(D), fp32
|
| 149 |
+
|
| 150 |
+
ignores:
|
| 151 |
+
[.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
|
| 152 |
+
"""
|
| 153 |
+
assert not with_complex
|
| 154 |
+
# https://github.com/state-spaces/mamba/issues/110
|
| 155 |
+
flops = 9 * B * L * D * N
|
| 156 |
+
if with_D:
|
| 157 |
+
flops += B * D * L
|
| 158 |
+
if with_Z:
|
| 159 |
+
flops += B * D * L
|
| 160 |
+
return flops
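# Worked example (illustrative): with the defaults B=1, L=256, D=768, N=16 and with_D=True,
#   9 * 1 * 256 * 768 * 16 = 28,311,552 flops for the scan
# + 1 * 768 * 256         =    196,608 flops for the D term
# = 28,508,160 flops in total.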
|
| 161 |
+
|
| 162 |
+
# this is only for selective_scan_ref...
|
| 163 |
+
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
|
| 164 |
+
"""
|
| 165 |
+
u: r(B D L)
|
| 166 |
+
delta: r(B D L)
|
| 167 |
+
A: r(D N)
|
| 168 |
+
B: r(B N L)
|
| 169 |
+
C: r(B N L)
|
| 170 |
+
D: r(D)
|
| 171 |
+
z: r(B D L)
|
| 172 |
+
delta_bias: r(D), fp32
|
| 173 |
+
|
| 174 |
+
ignores:
|
| 175 |
+
[.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
|
| 176 |
+
"""
|
| 177 |
+
import numpy as np
|
| 178 |
+
|
| 179 |
+
# fvcore.nn.jit_handles
|
| 180 |
+
def get_flops_einsum(input_shapes, equation):
|
| 181 |
+
np_arrs = [np.zeros(s) for s in input_shapes]
|
| 182 |
+
optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
|
| 183 |
+
for line in optim.split("\n"):
|
| 184 |
+
if "optimized flop" in line.lower():
|
| 185 |
+
# divided by 2 because we count MAC (multiply-add counted as one flop)
|
| 186 |
+
flop = float(np.floor(float(line.split(":")[-1]) / 2))
|
| 187 |
+
return flop
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
assert not with_complex
|
| 191 |
+
|
| 192 |
+
flops = 0  # accumulated below from the einsum and elementwise terms
|
| 193 |
+
|
| 194 |
+
flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
|
| 195 |
+
if with_Group:
|
| 196 |
+
flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
|
| 197 |
+
else:
|
| 198 |
+
flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
|
| 199 |
+
|
| 200 |
+
in_for_flops = B * D * N
|
| 201 |
+
if with_Group:
|
| 202 |
+
in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
|
| 203 |
+
else:
|
| 204 |
+
in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
|
| 205 |
+
flops += L * in_for_flops
|
| 206 |
+
if with_D:
|
| 207 |
+
flops += B * D * L
|
| 208 |
+
if with_Z:
|
| 209 |
+
flops += B * D * L
|
| 210 |
+
return flops
|
| 211 |
+
|
| 212 |
+
def selective_scan_flop_jit(inputs, outputs, backend="prefixsum", verbose=True):
|
| 213 |
+
if verbose:
|
| 214 |
+
print_jit_input_names(inputs)
|
| 215 |
+
flops_fn = flops_selective_scan_ref if backend == "naive" else flops_selective_scan_fn
|
| 216 |
+
B, D, L = inputs[0].type().sizes()
|
| 217 |
+
N = inputs[2].type().sizes()[1]
|
| 218 |
+
flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
|
| 219 |
+
return flops
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
if __name__ == "__main__":
|
| 223 |
+
def params(B, K, C, N, L, device = torch.device("cuda"), itype = torch.float):
|
| 224 |
+
As = (-0.5 * torch.rand(K * C, N, device=device, dtype=torch.float32)).requires_grad_()
|
| 225 |
+
Bs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
|
| 226 |
+
Cs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
|
| 227 |
+
Ds = torch.randn((K * C), device=device, dtype=torch.float32).requires_grad_()
|
| 228 |
+
u = torch.randn((B, K * C, L), device=device, dtype=itype).requires_grad_()
|
| 229 |
+
delta = (0.5 * torch.rand((B, K * C, L), device=device, dtype=itype)).requires_grad_()
|
| 230 |
+
delta_bias = (0.5 * torch.rand((K * C), device=device, dtype=torch.float32)).requires_grad_()
|
| 231 |
+
return u, delta, As, Bs, Cs, Ds, delta_bias
|
| 232 |
+
|
| 233 |
+
def bench(func, xs, Warmup=30, NTimes=20):
|
| 234 |
+
import time
|
| 235 |
+
torch.cuda.synchronize()
|
| 236 |
+
for r in range(Warmup):
|
| 237 |
+
for x in xs:
|
| 238 |
+
func(x)
|
| 239 |
+
torch.cuda.synchronize()
|
| 240 |
+
tim0 = time.time()
|
| 241 |
+
for r in range(NTimes):
|
| 242 |
+
for x in xs:
|
| 243 |
+
func(x)
|
| 244 |
+
torch.cuda.synchronize()
|
| 245 |
+
return (time.time() - tim0) / NTimes
|
| 246 |
+
|
| 247 |
+
def check():
|
| 248 |
+
u, delta, As, Bs, Cs, Ds, delta_bias = params(1, 4, 16, 8, 512, itype=torch.float16)
|
| 249 |
+
u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1 = [x.clone().detach().requires_grad_() for x in [u, delta, As, Bs, Cs, Ds, delta_bias]]
|
| 250 |
+
|
| 251 |
+
# out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="torch")
|
| 252 |
+
out = selective_scan_fn(u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1, True, backend="oflex")
|
| 253 |
+
out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="mamba")
|
| 254 |
+
print((out_ref - out).abs().max())
|
| 255 |
+
out.sum().backward()
|
| 256 |
+
out_ref.sum().backward()
|
| 257 |
+
for x, y in zip([u, As, Bs, Cs, Ds, delta, delta_bias], [u1, As1, Bs1, Cs1, Ds1, delta1, delta_bias1]):
|
| 258 |
+
print((x.grad - y.grad).abs().max())
|
| 259 |
+
|
| 260 |
+
u, delta, As, Bs, Cs, Ds, delta_bias = params(128, 4, 96, 8, 56 * 56)
|
| 261 |
+
print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="oflex"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
|
| 262 |
+
print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="mamba"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
|
| 263 |
+
print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="torch"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
|
| 264 |
+
|
| 265 |
+
check()
|
| 266 |
+
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/README.md
ADDED
|
@@ -0,0 +1,97 @@
|
| 1 |
+
# mamba-mini
|
| 2 |
+
An efficient implementation of selective scan in one file. It works on both CPU and GPU and comes with the corresponding mathematical derivation. It is probably the code closest to selective_scan_cuda in mamba.
|
| 3 |
+
|
| 4 |
+
### mathematical derivation
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
### code
|
| 8 |
+
```python
|
| 9 |
+
import torch
|
| 10 |
+
def selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64):
|
| 11 |
+
"""
|
| 12 |
+
# B: batch_size, G: groups, D: dim, N: state dim, L: seqlen
|
| 13 |
+
us: B, G * D, L
|
| 14 |
+
dts: B, G * D, L
|
| 15 |
+
As: G * D, N
|
| 16 |
+
Bs: B, G, N, L
|
| 17 |
+
Cs: B, G, N, L
|
| 18 |
+
Ds: G * D
|
| 19 |
+
delta_bias: G * D
|
| 20 |
+
# chunksize can be any value you like, but as chunksize grows, hs may degenerate to NaN, because exp(sum(delta) * A) becomes extremely small
|
| 21 |
+
"""
|
| 22 |
+
def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix):
|
| 23 |
+
"""
|
| 24 |
+
partial(h) / partial(t) = Ah + Bu; y = Ch + Du;
|
| 25 |
+
=> partial(h*exp(-At)) / partial(t) = Bu*exp(-At);
|
| 26 |
+
=> h_t = exp(At) * h_0 + sum_{0}_{t}_{Bu*exp(A(t-v)) dv};
|
| 27 |
+
=> h_b = exp(A(dt_a + ... + dt_{b-1})) * (h_a + sum_{a}_{b-1}_{Bu*exp(-A(dt_a + ... + dt_i)) dt_i});
|
| 28 |
+
y_i = C_i*h_i + D*u_i
|
| 29 |
+
"""
|
| 30 |
+
"""
|
| 31 |
+
us, dts: (L, B, G, D) # L is chunk_size
|
| 32 |
+
As: (G, D, N)
|
| 33 |
+
Bs, Cs: (L, B, G, N)
|
| 34 |
+
Ds: (G, D)
|
| 35 |
+
hprefix: (B, G, D, N)
|
| 36 |
+
"""
|
| 37 |
+
ts = dts.cumsum(dim=0)
|
| 38 |
+
Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp()
|
| 39 |
+
scale = Ats[-1].detach()
|
| 40 |
+
rAts = Ats / scale
|
| 41 |
+
duts = dts * us
|
| 42 |
+
dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs)
|
| 43 |
+
hs_tmp = rAts * (dtBus / rAts).cumsum(dim=0)
|
| 44 |
+
hs = hs_tmp + Ats * hprefix.unsqueeze(0)
|
| 45 |
+
ys = torch.einsum("lbgn,lbgdn->lbgd", Cs, hs)
|
| 46 |
+
return ys, hs
|
| 47 |
+
|
| 48 |
+
inp_dtype = us.dtype
|
| 49 |
+
has_D = Ds is not None
|
| 50 |
+
|
| 51 |
+
dts = dts.float()
|
| 52 |
+
if delta_bias is not None:
|
| 53 |
+
dts = dts + delta_bias.view(1, -1, 1).float()
|
| 54 |
+
if delta_softplus:
|
| 55 |
+
dts = torch.nn.functional.softplus(dts)
|
| 56 |
+
|
| 57 |
+
if len(Bs.shape) == 3:
|
| 58 |
+
Bs = Bs.unsqueeze(1)
|
| 59 |
+
if len(Cs.shape) == 3:
|
| 60 |
+
Cs = Cs.unsqueeze(1)
|
| 61 |
+
B, G, N, L = Bs.shape
|
| 62 |
+
us = us.view(B, G, -1, L).permute(3, 0, 1, 2).float()
|
| 63 |
+
dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).float()
|
| 64 |
+
As = As.view(G, -1, N).float()
|
| 65 |
+
Bs = Bs.permute(3, 0, 1, 2).float()
|
| 66 |
+
Cs = Cs.permute(3, 0, 1, 2).float()
|
| 67 |
+
Ds = Ds.view(G, -1).float() if has_D else None
|
| 68 |
+
D = As.shape[1]
|
| 69 |
+
|
| 70 |
+
oys = []
|
| 71 |
+
# ohs = []
|
| 72 |
+
hprefix = us.new_zeros((B, G, D, N), dtype=torch.float)
|
| 73 |
+
for i in range(0, L - 1, chunksize):
|
| 74 |
+
ys, hs = selective_scan_chunk(
|
| 75 |
+
us[i:i + chunksize], dts[i:i + chunksize],
|
| 76 |
+
As, Bs[i:i + chunksize], Cs[i:i + chunksize], hprefix,
|
| 77 |
+
)
|
| 78 |
+
oys.append(ys)
|
| 79 |
+
# ohs.append(hs)
|
| 80 |
+
hprefix = hs[-1]
|
| 81 |
+
|
| 82 |
+
oys = torch.cat(oys, dim=0)
|
| 83 |
+
# ohs = torch.cat(ohs, dim=0)
|
| 84 |
+
if has_D:
|
| 85 |
+
oys = oys + Ds * us
|
| 86 |
+
oys = oys.permute(1, 2, 3, 0).view(B, -1, L)
|
| 87 |
+
oys = oys.to(inp_dtype)
|
| 88 |
+
# hprefix = hprefix.to(inp_dtype)
|
| 89 |
+
|
| 90 |
+
return oys if not return_last_state else (oys, hprefix.view(B, G * D, N))
|
| 91 |
+
|
| 92 |
+
```
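A minimal usage sketch (the shapes below are arbitrary, chosen only for illustration):

```python
import torch

B, G, D, N, L = 2, 1, 4, 8, 128
us = torch.randn(B, G * D, L)
dts = 0.5 * torch.rand(B, G * D, L)
As = -0.5 * torch.rand(G * D, N)
Bs = torch.randn(B, G, N, L)
Cs = torch.randn(B, G, N, L)
Ds = torch.randn(G * D)

ys, last_state = selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_softplus=True, return_last_state=True)
print(ys.shape, last_state.shape)  # torch.Size([2, 4, 128]) torch.Size([2, 4, 8])
```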
|
| 93 |
+
|
| 94 |
+
### to test
|
| 95 |
+
```bash
|
| 96 |
+
pytest test_selective_scan.py
|
| 97 |
+
```
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3fcf524eeaf71e641653c1aff1f8fac591a1e6916300d322830bd02476873ab1
|
| 3 |
+
size 34969816
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e61c23bbef4f0b8f414187d9149e6e0818ce400c3351405d494b774b988bf6d
|
| 3 |
+
size 501136
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_log
ADDED
|
@@ -0,0 +1,4 @@
|
| 1 |
+
# ninja log v5
|
| 2 |
+
7 17272 1748140810026258300 /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o ab3bac6bd7b8268f
|
| 3 |
+
8 23832 1748140816810431800 /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o 7f9a77b388057fc6
|
| 4 |
+
7 57431 1748140850419474900 /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o 3cffffbdd6b9fec1
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/build.ninja
ADDED
|
@@ -0,0 +1,35 @@
|
| 1 |
+
ninja_required_version = 1.3
|
| 2 |
+
cxx = c++
|
| 3 |
+
nvcc = /usr/local/cuda-11.8/bin/nvcc
|
| 4 |
+
|
| 5 |
+
cflags = -pthread -B /root/anaconda3/envs/rscd/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.8/include -I/root/anaconda3/envs/rscd/include/python3.8 -c
|
| 6 |
+
post_cflags = -O3 -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=selective_scan_cuda_oflex -D_GLIBCXX_USE_CXX11_ABI=0
|
| 7 |
+
cuda_cflags = -I/mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.8/include -I/root/anaconda3/envs/rscd/include/python3.8 -c
|
| 8 |
+
cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT162_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math --ptxas-options=-v -lineinfo -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90 --threads 4 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=selective_scan_cuda_oflex -D_GLIBCXX_USE_CXX11_ABI=0
|
| 9 |
+
cuda_dlink_post_cflags =
|
| 10 |
+
ldflags =
|
| 11 |
+
|
| 12 |
+
rule compile
|
| 13 |
+
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
|
| 14 |
+
depfile = $out.d
|
| 15 |
+
deps = gcc
|
| 16 |
+
|
| 17 |
+
rule cuda_compile
|
| 18 |
+
depfile = $out.d
|
| 19 |
+
deps = gcc
|
| 20 |
+
command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
build /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o: cuda_compile /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu
|
| 27 |
+
build /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o: cuda_compile /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu
|
| 28 |
+
build /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o: compile /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a18749c332876fbeddf0356bbb0d4c979fffd89df7f4a27796e2dd39b523f2e
|
| 3 |
+
size 12294744
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:642208da032fd86584c184e13e9ed9d18a9c6e85d925770f7ce034e0d22774a9
size 13211880
rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0867e500c6352b2c0e1938ea0c8e6825aafbabc5699ec41d25a7793c56ed5d1e
size 14839600
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cub_extra.cuh
ADDED
@@ -0,0 +1,50 @@
// WarpMask is copied from /usr/local/cuda-12.1/include/cub/util_ptx.cuh
// PowerOfTwo is copied from /usr/local/cuda-12.1/include/cub/util_type.cuh

#pragma once

#include <cub/util_type.cuh>
#include <cub/util_arch.cuh>
#include <cub/util_namespace.cuh>
#include <cub/util_debug.cuh>

/**
 * \brief Statically determine if N is a power-of-two
 */
template <int N>
struct PowerOfTwo
{
    enum { VALUE = ((N & (N - 1)) == 0) };
};


/**
 * @brief Returns the warp mask for a warp of @p LOGICAL_WARP_THREADS threads
 *
 * @par
 * If the number of threads assigned to the virtual warp is not a power of two,
 * it's assumed that only one virtual warp exists.
 *
 * @tparam LOGICAL_WARP_THREADS <b>[optional]</b> The number of threads per
 *                              "logical" warp (may be less than the number of
 *                              hardware warp threads).
 * @param warp_id Id of virtual warp within architectural warp
 */
template <int LOGICAL_WARP_THREADS, int LEGACY_PTX_ARCH = 0>
__host__ __device__ __forceinline__
unsigned int WarpMask(unsigned int warp_id)
{
    constexpr bool is_pow_of_two = PowerOfTwo<LOGICAL_WARP_THREADS>::VALUE;
    constexpr bool is_arch_warp = LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0);

    unsigned int member_mask = 0xFFFFFFFFu >>
                               (CUB_WARP_THREADS(0) - LOGICAL_WARP_THREADS);

    if (is_pow_of_two && !is_arch_warp)
    {
        member_mask <<= warp_id * LOGICAL_WARP_THREADS;
    }

    return member_mask;
}
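A note on the helper above: WarpMask builds the lane mask for a logical sub-warp by taking a LOGICAL_WARP_THREADS-wide run of set bits and shifting it to that sub-warp's lanes. The snippet below is a minimal host-only sketch of the same arithmetic, assuming the architectural warp width (CUB_WARP_THREADS(0)) is 32; PowerOfTwoSketch and warp_mask_sketch are illustrative names, not part of the uploaded kernels.

// Sketch only (not from the upload): mirrors the WarpMask logic above for a
// 32-thread architectural warp. For a power-of-two logical warp narrower than
// 32 threads, the mask is shifted onto the lanes of that sub-warp.
#include <cstdio>

template <int N>
struct PowerOfTwoSketch { enum { VALUE = ((N & (N - 1)) == 0) }; };

template <int LOGICAL_WARP_THREADS>
unsigned int warp_mask_sketch(unsigned int warp_id) {
    constexpr int kArchWarpThreads = 32;  // assumed value of CUB_WARP_THREADS(0)
    constexpr bool is_pow_of_two = PowerOfTwoSketch<LOGICAL_WARP_THREADS>::VALUE;
    constexpr bool is_arch_warp = LOGICAL_WARP_THREADS == kArchWarpThreads;
    unsigned int member_mask = 0xFFFFFFFFu >> (kArchWarpThreads - LOGICAL_WARP_THREADS);
    if (is_pow_of_two && !is_arch_warp) {
        member_mask <<= warp_id * LOGICAL_WARP_THREADS;
    }
    return member_mask;
}

int main() {
    // Four 8-thread logical warps inside one 32-thread hardware warp:
    // prints 0x000000ff, 0x0000ff00, 0x00ff0000, 0xff000000.
    for (unsigned int w = 0; w < 4; ++w) {
        std::printf("0x%08x\n", warp_mask_sketch<8>(w));
    }
    return 0;
}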
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan.cpp
ADDED
|
@@ -0,0 +1,354 @@
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 6 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 7 |
+
#include <torch/extension.h>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#include "selective_scan.h"
|
| 11 |
+
#define MAX_DSTATE 256
|
| 12 |
+
|
| 13 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 14 |
+
using weight_t = float;
|
| 15 |
+
|
| 16 |
+
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
| 17 |
+
if (ITYPE == at::ScalarType::Half) { \
|
| 18 |
+
using input_t = at::Half; \
|
| 19 |
+
__VA_ARGS__(); \
|
| 20 |
+
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
| 21 |
+
using input_t = at::BFloat16; \
|
| 22 |
+
__VA_ARGS__(); \
|
| 23 |
+
} else if (ITYPE == at::ScalarType::Float) { \
|
| 24 |
+
using input_t = float; \
|
| 25 |
+
__VA_ARGS__(); \
|
| 26 |
+
} else { \
|
| 27 |
+
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
template<int knrows, typename input_t, typename weight_t>
|
| 31 |
+
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
| 32 |
+
|
| 33 |
+
template <int knrows, typename input_t, typename weight_t>
|
| 34 |
+
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream);
|
| 35 |
+
|
| 36 |
+
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
| 37 |
+
// sizes
|
| 38 |
+
const size_t batch,
|
| 39 |
+
const size_t dim,
|
| 40 |
+
const size_t seqlen,
|
| 41 |
+
const size_t dstate,
|
| 42 |
+
const size_t n_groups,
|
| 43 |
+
const size_t n_chunks,
|
| 44 |
+
// device pointers
|
| 45 |
+
const at::Tensor u,
|
| 46 |
+
const at::Tensor delta,
|
| 47 |
+
const at::Tensor A,
|
| 48 |
+
const at::Tensor B,
|
| 49 |
+
const at::Tensor C,
|
| 50 |
+
const at::Tensor out,
|
| 51 |
+
void* D_ptr,
|
| 52 |
+
void* delta_bias_ptr,
|
| 53 |
+
void* x_ptr,
|
| 54 |
+
bool delta_softplus) {
|
| 55 |
+
|
| 56 |
+
// Reset the parameters
|
| 57 |
+
memset(¶ms, 0, sizeof(params));
|
| 58 |
+
|
| 59 |
+
params.batch = batch;
|
| 60 |
+
params.dim = dim;
|
| 61 |
+
params.seqlen = seqlen;
|
| 62 |
+
params.dstate = dstate;
|
| 63 |
+
params.n_groups = n_groups;
|
| 64 |
+
params.n_chunks = n_chunks;
|
| 65 |
+
params.dim_ngroups_ratio = dim / n_groups;
|
| 66 |
+
|
| 67 |
+
params.delta_softplus = delta_softplus;
|
| 68 |
+
|
| 69 |
+
// Set the pointers and strides.
|
| 70 |
+
params.u_ptr = u.data_ptr();
|
| 71 |
+
params.delta_ptr = delta.data_ptr();
|
| 72 |
+
params.A_ptr = A.data_ptr();
|
| 73 |
+
params.B_ptr = B.data_ptr();
|
| 74 |
+
params.C_ptr = C.data_ptr();
|
| 75 |
+
params.D_ptr = D_ptr;
|
| 76 |
+
params.delta_bias_ptr = delta_bias_ptr;
|
| 77 |
+
params.out_ptr = out.data_ptr();
|
| 78 |
+
params.x_ptr = x_ptr;
|
| 79 |
+
|
| 80 |
+
// All stride are in elements, not bytes.
|
| 81 |
+
params.A_d_stride = A.stride(0);
|
| 82 |
+
params.A_dstate_stride = A.stride(1);
|
| 83 |
+
params.B_batch_stride = B.stride(0);
|
| 84 |
+
params.B_group_stride = B.stride(1);
|
| 85 |
+
params.B_dstate_stride = B.stride(2);
|
| 86 |
+
params.C_batch_stride = C.stride(0);
|
| 87 |
+
params.C_group_stride = C.stride(1);
|
| 88 |
+
params.C_dstate_stride = C.stride(2);
|
| 89 |
+
params.u_batch_stride = u.stride(0);
|
| 90 |
+
params.u_d_stride = u.stride(1);
|
| 91 |
+
params.delta_batch_stride = delta.stride(0);
|
| 92 |
+
params.delta_d_stride = delta.stride(1);
|
| 93 |
+
|
| 94 |
+
params.out_batch_stride = out.stride(0);
|
| 95 |
+
params.out_d_stride = out.stride(1);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
void set_ssm_params_bwd(SSMParamsBwd ¶ms,
|
| 99 |
+
// sizes
|
| 100 |
+
const size_t batch,
|
| 101 |
+
const size_t dim,
|
| 102 |
+
const size_t seqlen,
|
| 103 |
+
const size_t dstate,
|
| 104 |
+
const size_t n_groups,
|
| 105 |
+
const size_t n_chunks,
|
| 106 |
+
// device pointers
|
| 107 |
+
const at::Tensor u,
|
| 108 |
+
const at::Tensor delta,
|
| 109 |
+
const at::Tensor A,
|
| 110 |
+
const at::Tensor B,
|
| 111 |
+
const at::Tensor C,
|
| 112 |
+
const at::Tensor out,
|
| 113 |
+
void* D_ptr,
|
| 114 |
+
void* delta_bias_ptr,
|
| 115 |
+
void* x_ptr,
|
| 116 |
+
const at::Tensor dout,
|
| 117 |
+
const at::Tensor du,
|
| 118 |
+
const at::Tensor ddelta,
|
| 119 |
+
const at::Tensor dA,
|
| 120 |
+
const at::Tensor dB,
|
| 121 |
+
const at::Tensor dC,
|
| 122 |
+
void* dD_ptr,
|
| 123 |
+
void* ddelta_bias_ptr,
|
| 124 |
+
bool delta_softplus) {
|
| 125 |
+
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
|
| 126 |
+
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks,
|
| 127 |
+
u, delta, A, B, C, dout,
|
| 128 |
+
D_ptr, delta_bias_ptr, x_ptr, delta_softplus);
|
| 129 |
+
|
| 130 |
+
// Set the pointers and strides.
|
| 131 |
+
params.dout_ptr = dout.data_ptr();
|
| 132 |
+
params.du_ptr = du.data_ptr();
|
| 133 |
+
params.dA_ptr = dA.data_ptr();
|
| 134 |
+
params.dB_ptr = dB.data_ptr();
|
| 135 |
+
params.dC_ptr = dC.data_ptr();
|
| 136 |
+
params.dD_ptr = dD_ptr;
|
| 137 |
+
params.ddelta_ptr = ddelta.data_ptr();
|
| 138 |
+
params.ddelta_bias_ptr = ddelta_bias_ptr;
|
| 139 |
+
// All stride are in elements, not bytes.
|
| 140 |
+
params.dout_batch_stride = dout.stride(0);
|
| 141 |
+
params.dout_d_stride = dout.stride(1);
|
| 142 |
+
params.dA_d_stride = dA.stride(0);
|
| 143 |
+
params.dA_dstate_stride = dA.stride(1);
|
| 144 |
+
params.dB_batch_stride = dB.stride(0);
|
| 145 |
+
params.dB_group_stride = dB.stride(1);
|
| 146 |
+
params.dB_dstate_stride = dB.stride(2);
|
| 147 |
+
params.dC_batch_stride = dC.stride(0);
|
| 148 |
+
params.dC_group_stride = dC.stride(1);
|
| 149 |
+
params.dC_dstate_stride = dC.stride(2);
|
| 150 |
+
params.du_batch_stride = du.stride(0);
|
| 151 |
+
params.du_d_stride = du.stride(1);
|
| 152 |
+
params.ddelta_batch_stride = ddelta.stride(0);
|
| 153 |
+
params.ddelta_d_stride = ddelta.stride(1);
|
| 154 |
+
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
std::vector<at::Tensor>
|
| 158 |
+
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
|
| 159 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
| 160 |
+
const c10::optional<at::Tensor> &D_,
|
| 161 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
| 162 |
+
bool delta_softplus,
|
| 163 |
+
int nrows
|
| 164 |
+
) {
|
| 165 |
+
auto input_type = u.scalar_type();
|
| 166 |
+
auto weight_type = A.scalar_type();
|
| 167 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 168 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
| 169 |
+
|
| 170 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
| 171 |
+
TORCH_CHECK(B.scalar_type() == input_type);
|
| 172 |
+
TORCH_CHECK(C.scalar_type() == input_type);
|
| 173 |
+
|
| 174 |
+
TORCH_CHECK(u.is_cuda());
|
| 175 |
+
TORCH_CHECK(delta.is_cuda());
|
| 176 |
+
TORCH_CHECK(A.is_cuda());
|
| 177 |
+
TORCH_CHECK(B.is_cuda());
|
| 178 |
+
TORCH_CHECK(C.is_cuda());
|
| 179 |
+
|
| 180 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
| 181 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
| 182 |
+
|
| 183 |
+
const auto sizes = u.sizes();
|
| 184 |
+
const int batch_size = sizes[0];
|
| 185 |
+
const int dim = sizes[1];
|
| 186 |
+
const int seqlen = sizes[2];
|
| 187 |
+
const int dstate = A.size(1);
|
| 188 |
+
const int n_groups = B.size(1);
|
| 189 |
+
|
| 190 |
+
TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups");
|
| 191 |
+
TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256");
|
| 192 |
+
|
| 193 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
| 194 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
| 195 |
+
CHECK_SHAPE(A, dim, dstate);
|
| 196 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
| 197 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
| 198 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
| 199 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
| 200 |
+
|
| 201 |
+
if (D_.has_value()) {
|
| 202 |
+
auto D = D_.value();
|
| 203 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
| 204 |
+
TORCH_CHECK(D.is_cuda());
|
| 205 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
| 206 |
+
CHECK_SHAPE(D, dim);
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
if (delta_bias_.has_value()) {
|
| 210 |
+
auto delta_bias = delta_bias_.value();
|
| 211 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
| 212 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
| 213 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
| 214 |
+
CHECK_SHAPE(delta_bias, dim);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel
|
| 218 |
+
at::Tensor out = torch::empty_like(delta);
|
| 219 |
+
at::Tensor x;
|
| 220 |
+
x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
|
| 221 |
+
|
| 222 |
+
SSMParamsBase params;
|
| 223 |
+
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
|
| 224 |
+
u, delta, A, B, C, out,
|
| 225 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
| 226 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
| 227 |
+
x.data_ptr(),
|
| 228 |
+
delta_softplus);
|
| 229 |
+
|
| 230 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 231 |
+
// Cast to char to avoid compiler warning about narrowing
|
| 232 |
+
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
| 233 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 234 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
| 235 |
+
selective_scan_fwd_cuda<1, input_t, weight_t>(params, stream);
|
| 236 |
+
});
|
| 237 |
+
std::vector<at::Tensor> result = {out, x};
|
| 238 |
+
return result;
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
std::vector<at::Tensor>
|
| 242 |
+
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
|
| 243 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
| 244 |
+
const c10::optional<at::Tensor> &D_,
|
| 245 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
| 246 |
+
const at::Tensor &dout,
|
| 247 |
+
const c10::optional<at::Tensor> &x_,
|
| 248 |
+
bool delta_softplus,
|
| 249 |
+
int nrows
|
| 250 |
+
) {
|
| 251 |
+
auto input_type = u.scalar_type();
|
| 252 |
+
auto weight_type = A.scalar_type();
|
| 253 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 254 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
| 255 |
+
|
| 256 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
| 257 |
+
TORCH_CHECK(B.scalar_type() == input_type);
|
| 258 |
+
TORCH_CHECK(C.scalar_type() == input_type);
|
| 259 |
+
TORCH_CHECK(dout.scalar_type() == input_type);
|
| 260 |
+
|
| 261 |
+
TORCH_CHECK(u.is_cuda());
|
| 262 |
+
TORCH_CHECK(delta.is_cuda());
|
| 263 |
+
TORCH_CHECK(A.is_cuda());
|
| 264 |
+
TORCH_CHECK(B.is_cuda());
|
| 265 |
+
TORCH_CHECK(C.is_cuda());
|
| 266 |
+
TORCH_CHECK(dout.is_cuda());
|
| 267 |
+
|
| 268 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
| 269 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
| 270 |
+
TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
|
| 271 |
+
|
| 272 |
+
const auto sizes = u.sizes();
|
| 273 |
+
const int batch_size = sizes[0];
|
| 274 |
+
const int dim = sizes[1];
|
| 275 |
+
const int seqlen = sizes[2];
|
| 276 |
+
const int dstate = A.size(1);
|
| 277 |
+
const int n_groups = B.size(1);
|
| 278 |
+
|
| 279 |
+
TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups");
|
| 280 |
+
TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256");
|
| 281 |
+
|
| 282 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
| 283 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
| 284 |
+
CHECK_SHAPE(A, dim, dstate);
|
| 285 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
| 286 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
| 287 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
| 288 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
| 289 |
+
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
| 290 |
+
|
| 291 |
+
if (D_.has_value()) {
|
| 292 |
+
auto D = D_.value();
|
| 293 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
| 294 |
+
TORCH_CHECK(D.is_cuda());
|
| 295 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
| 296 |
+
CHECK_SHAPE(D, dim);
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
if (delta_bias_.has_value()) {
|
| 300 |
+
auto delta_bias = delta_bias_.value();
|
| 301 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
| 302 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
| 303 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
| 304 |
+
CHECK_SHAPE(delta_bias, dim);
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
at::Tensor out;
|
| 308 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
| 309 |
+
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
| 310 |
+
if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
|
| 311 |
+
if (x_.has_value()) {
|
| 312 |
+
auto x = x_.value();
|
| 313 |
+
TORCH_CHECK(x.scalar_type() == weight_type);
|
| 314 |
+
TORCH_CHECK(x.is_cuda());
|
| 315 |
+
TORCH_CHECK(x.is_contiguous());
|
| 316 |
+
CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
at::Tensor du = torch::empty_like(u);
|
| 320 |
+
at::Tensor ddelta = torch::empty_like(delta);
|
| 321 |
+
at::Tensor dA = torch::zeros_like(A);
|
| 322 |
+
at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32));
|
| 323 |
+
at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32));
|
| 324 |
+
at::Tensor dD;
|
| 325 |
+
if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
|
| 326 |
+
at::Tensor ddelta_bias;
|
| 327 |
+
if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
|
| 328 |
+
|
| 329 |
+
SSMParamsBwd params;
|
| 330 |
+
set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
|
| 331 |
+
u, delta, A, B, C, out,
|
| 332 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
| 333 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
| 334 |
+
x_.has_value() ? x_.value().data_ptr() : nullptr,
|
| 335 |
+
dout, du, ddelta, dA, dB, dC,
|
| 336 |
+
D_.has_value() ? dD.data_ptr() : nullptr,
|
| 337 |
+
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
|
| 338 |
+
delta_softplus);
|
| 339 |
+
|
| 340 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 341 |
+
// Cast to char to avoid compiler warning about narrowing
|
| 342 |
+
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
| 343 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 344 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
|
| 345 |
+
selective_scan_bwd_cuda<1, input_t, weight_t>(params, stream);
|
| 346 |
+
});
|
| 347 |
+
std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
|
| 348 |
+
return result;
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 352 |
+
m.def("fwd", &selective_scan_fwd, "Selective scan forward");
|
| 353 |
+
m.def("bwd", &selective_scan_bwd, "Selective scan backward");
|
| 354 |
+
}
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_bwd_kernel.cuh
ADDED
|
@@ -0,0 +1,306 @@
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <c10/util/BFloat16.h>
|
| 8 |
+
#include <c10/util/Half.h>
|
| 9 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 10 |
+
#include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
|
| 11 |
+
|
| 12 |
+
#include <cub/block/block_load.cuh>
|
| 13 |
+
#include <cub/block/block_store.cuh>
|
| 14 |
+
#include <cub/block/block_scan.cuh>
|
| 15 |
+
#include <cub/block/block_reduce.cuh>
|
| 16 |
+
|
| 17 |
+
#include "selective_scan.h"
|
| 18 |
+
#include "selective_scan_common.h"
|
| 19 |
+
#include "reverse_scan.cuh"
|
| 20 |
+
#include "static_switch.h"
|
| 21 |
+
|
| 22 |
+
template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kDeltaSoftplus_, typename input_t_, typename weight_t_>
|
| 23 |
+
struct Selective_Scan_bwd_kernel_traits {
|
| 24 |
+
static_assert(kNItems_ % 4 == 0);
|
| 25 |
+
using input_t = input_t_;
|
| 26 |
+
using weight_t = weight_t_;
|
| 27 |
+
static constexpr int kNThreads = kNThreads_;
|
| 28 |
+
static constexpr int kNItems = kNItems_;
|
| 29 |
+
static constexpr int MaxDState = MAX_DSTATE;
|
| 30 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 31 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 32 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
|
| 33 |
+
static_assert(kNItems % kNElts == 0);
|
| 34 |
+
static constexpr int kNLoads = kNItems / kNElts;
|
| 35 |
+
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
| 36 |
+
static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
|
| 37 |
+
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
|
| 38 |
+
// For complex this would lead to massive register spilling, so we keep it at 2.
|
| 39 |
+
static constexpr int kMinBlocks = kNThreads == 128 && 3;
|
| 40 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 41 |
+
using scan_t = float2;
|
| 42 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 43 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 44 |
+
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 45 |
+
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 46 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 47 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 48 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
| 49 |
+
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
| 50 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
| 51 |
+
using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
|
| 52 |
+
using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
|
| 53 |
+
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
|
| 54 |
+
using BlockExchangeT = cub::BlockExchange<float, kNThreads, kNItems>;
|
| 55 |
+
static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
| 56 |
+
sizeof(typename BlockLoadVecT::TempStorage),
|
| 57 |
+
2 * sizeof(typename BlockLoadWeightT::TempStorage),
|
| 58 |
+
2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
| 59 |
+
sizeof(typename BlockStoreT::TempStorage),
|
| 60 |
+
sizeof(typename BlockStoreVecT::TempStorage)});
|
| 61 |
+
static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage);
|
| 62 |
+
static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
|
| 63 |
+
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
template<typename Ktraits>
|
| 67 |
+
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
| 68 |
+
void selective_scan_bwd_kernel(SSMParamsBwd params) {
|
| 69 |
+
constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
|
| 70 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 71 |
+
constexpr int kNItems = Ktraits::kNItems;
|
| 72 |
+
using input_t = typename Ktraits::input_t;
|
| 73 |
+
using weight_t = typename Ktraits::weight_t;
|
| 74 |
+
using scan_t = typename Ktraits::scan_t;
|
| 75 |
+
|
| 76 |
+
// Shared memory.
|
| 77 |
+
extern __shared__ char smem_[];
|
| 78 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
| 79 |
+
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
| 80 |
+
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
| 81 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
| 82 |
+
auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
| 83 |
+
auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
|
| 84 |
+
auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
|
| 85 |
+
auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
|
| 86 |
+
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
|
| 87 |
+
auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
|
| 88 |
+
weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
|
| 89 |
+
scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads);
|
| 90 |
+
weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + Ktraits::MaxDState);
|
| 91 |
+
|
| 92 |
+
const int batch_id = blockIdx.x;
|
| 93 |
+
const int dim_id = blockIdx.y;
|
| 94 |
+
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
| 95 |
+
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
| 96 |
+
+ dim_id * params.u_d_stride;
|
| 97 |
+
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
| 98 |
+
+ dim_id * params.delta_d_stride;
|
| 99 |
+
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
| 100 |
+
+ dim_id * params.dout_d_stride;
|
| 101 |
+
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
|
| 102 |
+
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
| 103 |
+
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
| 104 |
+
weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
|
| 105 |
+
weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
|
| 106 |
+
+ (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride);
|
| 107 |
+
weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
|
| 108 |
+
+ (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride);
|
| 109 |
+
float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
|
| 110 |
+
float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
|
| 111 |
+
float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
|
| 112 |
+
float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
|
| 113 |
+
scan_t *x = params.x_ptr == nullptr
|
| 114 |
+
? nullptr
|
| 115 |
+
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
|
| 116 |
+
float dD_val = 0;
|
| 117 |
+
float ddelta_bias_val = 0;
|
| 118 |
+
|
| 119 |
+
constexpr int kChunkSize = kNThreads * kNItems;
|
| 120 |
+
u += (params.n_chunks - 1) * kChunkSize;
|
| 121 |
+
delta += (params.n_chunks - 1) * kChunkSize;
|
| 122 |
+
dout += (params.n_chunks - 1) * kChunkSize;
|
| 123 |
+
Bvar += (params.n_chunks - 1) * kChunkSize;
|
| 124 |
+
Cvar += (params.n_chunks - 1) * kChunkSize;
|
| 125 |
+
for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
|
| 126 |
+
input_t u_vals[kNItems];
|
| 127 |
+
input_t delta_vals_load[kNItems];
|
| 128 |
+
input_t dout_vals_load[kNItems];
|
| 129 |
+
__syncthreads();
|
| 130 |
+
load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
| 131 |
+
__syncthreads();
|
| 132 |
+
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
| 133 |
+
__syncthreads();
|
| 134 |
+
load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
| 135 |
+
u -= kChunkSize;
|
| 136 |
+
// Will reload delta at the same location if kDeltaSoftplus
|
| 137 |
+
if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
|
| 138 |
+
dout -= kChunkSize;
|
| 139 |
+
|
| 140 |
+
float dout_vals[kNItems], delta_vals[kNItems];
|
| 141 |
+
float du_vals[kNItems];
|
| 142 |
+
#pragma unroll
|
| 143 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 144 |
+
dout_vals[i] = float(dout_vals_load[i]);
|
| 145 |
+
delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
|
| 146 |
+
if constexpr (kDeltaSoftplus) {
|
| 147 |
+
delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
#pragma unroll
|
| 151 |
+
for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
|
| 152 |
+
#pragma unroll
|
| 153 |
+
for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
|
| 154 |
+
|
| 155 |
+
float ddelta_vals[kNItems] = {0};
|
| 156 |
+
__syncthreads();
|
| 157 |
+
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
| 158 |
+
constexpr float kLog2e = M_LOG2E;
|
| 159 |
+
weight_t A_val = A[state_idx * params.A_dstate_stride];
|
| 160 |
+
weight_t A_scaled = A_val * kLog2e;
|
| 161 |
+
weight_t B_vals[kNItems], C_vals[kNItems];
|
| 162 |
+
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
| 163 |
+
smem_load_weight, (params.seqlen - chunk * kChunkSize));
|
| 164 |
+
auto &smem_load_weight_C = smem_load_weight1;
|
| 165 |
+
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
| 166 |
+
smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
|
| 167 |
+
scan_t thread_data[kNItems], thread_reverse_data[kNItems];
|
| 168 |
+
#pragma unroll
|
| 169 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 170 |
+
const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
|
| 171 |
+
thread_data[i] = make_float2(delta_a_exp, delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
| 172 |
+
if (i == 0) {
|
| 173 |
+
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState: threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp;
|
| 174 |
+
} else {
|
| 175 |
+
thread_reverse_data[i - 1].x = delta_a_exp;
|
| 176 |
+
}
|
| 177 |
+
thread_reverse_data[i].y = dout_vals[i] * C_vals[i];
|
| 178 |
+
}
|
| 179 |
+
__syncthreads();
|
| 180 |
+
thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
|
| 181 |
+
? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState])
|
| 182 |
+
: smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState];
|
| 183 |
+
// Initialize running total
|
| 184 |
+
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
|
| 185 |
+
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
| 186 |
+
Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
| 187 |
+
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
| 188 |
+
);
|
| 189 |
+
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
|
| 190 |
+
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
|
| 191 |
+
Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
|
| 192 |
+
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
|
| 193 |
+
);
|
| 194 |
+
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
|
| 195 |
+
weight_t dA_val = 0;
|
| 196 |
+
weight_t dB_vals[kNItems], dC_vals[kNItems];
|
| 197 |
+
#pragma unroll
|
| 198 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 199 |
+
const float dx = thread_reverse_data[i].y;
|
| 200 |
+
const float ddelta_u = dx * B_vals[i];
|
| 201 |
+
du_vals[i] += ddelta_u * delta_vals[i];
|
| 202 |
+
const float a = thread_data[i].y - (delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
| 203 |
+
ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
|
| 204 |
+
dA_val += dx * delta_vals[i] * a;
|
| 205 |
+
dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]);
|
| 206 |
+
dC_vals[i] = dout_vals[i] * thread_data[i].y;
|
| 207 |
+
}
|
| 208 |
+
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
|
| 209 |
+
Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
|
| 210 |
+
auto &smem_exchange_C = smem_exchange1;
|
| 211 |
+
Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
|
| 212 |
+
const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
|
| 213 |
+
weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
| 214 |
+
weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
| 215 |
+
#pragma unroll
|
| 216 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 217 |
+
if (i * kNThreads < seqlen_remaining) {
|
| 218 |
+
{ gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
|
| 219 |
+
{ gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
|
| 223 |
+
if (threadIdx.x == 0) {
|
| 224 |
+
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
if constexpr (kDeltaSoftplus) {
|
| 229 |
+
input_t delta_vals_load[kNItems];
|
| 230 |
+
__syncthreads();
|
| 231 |
+
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
| 232 |
+
delta -= kChunkSize;
|
| 233 |
+
#pragma unroll
|
| 234 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 235 |
+
float delta_val = float(delta_vals_load[i]) + delta_bias;
|
| 236 |
+
float delta_val_neg_exp = expf(-delta_val);
|
| 237 |
+
ddelta_vals[i] = delta_val <= 20.f
|
| 238 |
+
? ddelta_vals[i] / (1.f + delta_val_neg_exp)
|
| 239 |
+
: ddelta_vals[i];
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
__syncthreads();
|
| 244 |
+
#pragma unroll
|
| 245 |
+
for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
|
| 246 |
+
|
| 247 |
+
input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
|
| 248 |
+
+ dim_id * params.du_d_stride + chunk * kChunkSize;
|
| 249 |
+
input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
|
| 250 |
+
+ dim_id * params.ddelta_d_stride + chunk * kChunkSize;
|
| 251 |
+
__syncthreads();
|
| 252 |
+
store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
| 253 |
+
__syncthreads();
|
| 254 |
+
store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
| 255 |
+
Bvar -= kChunkSize;
|
| 256 |
+
Cvar -= kChunkSize;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
if (params.dD_ptr != nullptr) {
|
| 260 |
+
__syncthreads();
|
| 261 |
+
dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
|
| 262 |
+
if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
|
| 263 |
+
}
|
| 264 |
+
if (params.ddelta_bias_ptr != nullptr) {
|
| 265 |
+
__syncthreads();
|
| 266 |
+
ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
|
| 267 |
+
if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
|
| 268 |
+
}
|
| 269 |
+
__syncthreads();
|
| 270 |
+
for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
| 271 |
+
gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
| 276 |
+
void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
| 277 |
+
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
| 278 |
+
BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
|
| 279 |
+
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kDeltaSoftplus, input_t, weight_t>;
|
| 280 |
+
constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t);
|
| 281 |
+
// printf("smem_size = %d\n", kSmemSize);
|
| 282 |
+
dim3 grid(params.batch, params.dim);
|
| 283 |
+
auto kernel = &selective_scan_bwd_kernel<Ktraits>;
|
| 284 |
+
if (kSmemSize >= 48 * 1024) {
|
| 285 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 286 |
+
}
|
| 287 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 288 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 289 |
+
});
|
| 290 |
+
});
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
template<int knrows, typename input_t, typename weight_t>
|
| 294 |
+
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
| 295 |
+
if (params.seqlen <= 128) {
|
| 296 |
+
selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
|
| 297 |
+
} else if (params.seqlen <= 256) {
|
| 298 |
+
selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
|
| 299 |
+
} else if (params.seqlen <= 512) {
|
| 300 |
+
selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
|
| 301 |
+
} else if (params.seqlen <= 1024) {
|
| 302 |
+
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
|
| 303 |
+
} else {
|
| 304 |
+
selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
|
| 305 |
+
}
|
| 306 |
+
}
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_bwd.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_fwd.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_fwd_kernel.cuh"

template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
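The two small .cu files above follow a common split: the templated launcher is defined in the kernel header, and each translation unit only emits explicit instantiations for the supported input types, so the forward and backward objects compile as separate, parallel ninja jobs. Below is a minimal sketch of that pattern with hypothetical names (scan_entry stands in for the real launcher); it is an illustration, not the library's code.

// Sketch only: explicit-instantiation split across translation units.
#include <cstdio>

// --- what would live in a "*_kernel.cuh"-style header ---
template <int kNRows, typename input_t, typename weight_t>
void scan_entry(int seqlen) {
    // Stand-in for the real kernel launcher; just reports which
    // instantiation was selected at compile time.
    std::printf("scan_entry<%d, %zu-byte input> seqlen=%d\n",
                kNRows, sizeof(input_t), seqlen);
}

// --- what a "*_core_fwd.cu"-style translation unit would contain ---
template void scan_entry<1, float, float>(int seqlen);
template void scan_entry<1, short, float>(int seqlen);  // stands in for at::Half

int main() {
    scan_entry<1, float, float>(1024);
    scan_entry<1, short, float>(1024);
    return 0;
}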
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_fwd_kernel.cuh
ADDED
|
@@ -0,0 +1,203 @@
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <c10/util/BFloat16.h>
|
| 8 |
+
#include <c10/util/Half.h>
|
| 9 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 10 |
+
|
| 11 |
+
#include <cub/block/block_load.cuh>
|
| 12 |
+
#include <cub/block/block_store.cuh>
|
| 13 |
+
#include <cub/block/block_scan.cuh>
|
| 14 |
+
|
| 15 |
+
#include "selective_scan.h"
|
| 16 |
+
#include "selective_scan_common.h"
|
| 17 |
+
#include "static_switch.h"
|
| 18 |
+
|
| 19 |
+
template<int kNThreads_, int kNItems_, bool kIsEvenLen_, typename input_t_, typename weight_t_>
|
| 20 |
+
struct Selective_Scan_fwd_kernel_traits {
|
| 21 |
+
static_assert(kNItems_ % 4 == 0);
|
| 22 |
+
using input_t = input_t_;
|
| 23 |
+
using weight_t = weight_t_;
|
| 24 |
+
static constexpr int kNThreads = kNThreads_;
|
| 25 |
+
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
| 26 |
+
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
| 27 |
+
static constexpr int kNItems = kNItems_;
|
| 28 |
+
static constexpr int MaxDState = MAX_DSTATE;
|
| 29 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 30 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 31 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
|
| 32 |
+
static_assert(kNItems % kNElts == 0);
|
| 33 |
+
static constexpr int kNLoads = kNItems / kNElts;
|
| 34 |
+
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
| 35 |
+
|
| 36 |
+
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
|
| 37 |
+
|
| 38 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 39 |
+
using scan_t = float2;
|
| 40 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 41 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
| 42 |
+
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
| 43 |
+
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 44 |
+
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
| 45 |
+
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
| 46 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 47 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
|
| 48 |
+
!kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
| 49 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
| 50 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
| 51 |
+
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
| 52 |
+
static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
| 53 |
+
sizeof(typename BlockLoadVecT::TempStorage),
|
| 54 |
+
2 * sizeof(typename BlockLoadWeightT::TempStorage),
|
| 55 |
+
2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
| 56 |
+
sizeof(typename BlockStoreT::TempStorage),
|
| 57 |
+
sizeof(typename BlockStoreVecT::TempStorage)});
|
| 58 |
+
static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
|
| 59 |
+
};
|
| 60 |
+
|
| 61 |
+
template<typename Ktraits>
|
| 62 |
+
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
| 63 |
+
void selective_scan_fwd_kernel(SSMParamsBase params) {
|
| 64 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 65 |
+
constexpr int kNItems = Ktraits::kNItems;
|
| 66 |
+
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
| 67 |
+
using input_t = typename Ktraits::input_t;
|
| 68 |
+
using weight_t = typename Ktraits::weight_t;
|
| 69 |
+
using scan_t = typename Ktraits::scan_t;
|
| 70 |
+
|
| 71 |
+
// Shared memory.
|
| 72 |
+
extern __shared__ char smem_[];
|
| 73 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
| 74 |
+
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
| 75 |
+
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
| 76 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
| 77 |
+
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
| 78 |
+
scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
|
| 79 |
+
|
| 80 |
+
const int batch_id = blockIdx.x;
|
| 81 |
+
const int dim_id = blockIdx.y;
|
| 82 |
+
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
| 83 |
+
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
| 84 |
+
+ dim_id * params.u_d_stride;
|
| 85 |
+
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
| 86 |
+
+ dim_id * params.delta_d_stride;
|
| 87 |
+
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
|
| 88 |
+
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
| 89 |
+
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
| 90 |
+
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks * params.dstate;
|
| 91 |
+
|
| 92 |
+
float D_val = 0; // attention!
|
| 93 |
+
if (params.D_ptr != nullptr) {
|
| 94 |
+
D_val = reinterpret_cast<float *>(params.D_ptr)[dim_id];
|
| 95 |
+
}
|
| 96 |
+
float delta_bias = 0;
|
| 97 |
+
if (params.delta_bias_ptr != nullptr) {
|
| 98 |
+
delta_bias = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
constexpr int kChunkSize = kNThreads * kNItems;
|
| 102 |
+
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
|
| 103 |
+
input_t u_vals[kNItems], delta_vals_load[kNItems];
|
| 104 |
+
__syncthreads();
|
| 105 |
+
load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
| 106 |
+
if constexpr (!kDirectIO) { __syncthreads(); }
|
| 107 |
+
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
| 108 |
+
u += kChunkSize;
|
| 109 |
+
delta += kChunkSize;
|
| 110 |
+
|
| 111 |
+
float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems];
|
| 112 |
+
#pragma unroll
|
| 113 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 114 |
+
float u_val = float(u_vals[i]);
|
| 115 |
+
delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
|
| 116 |
+
if (params.delta_softplus) {
|
| 117 |
+
delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
|
| 118 |
+
}
|
| 119 |
+
delta_u_vals[i] = delta_vals[i] * u_val;
|
| 120 |
+
out_vals[i] = D_val * u_val;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
__syncthreads();
|
| 124 |
+
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
| 125 |
+
constexpr float kLog2e = M_LOG2E;
|
| 126 |
+
weight_t A_val = A[state_idx * params.A_dstate_stride];
|
| 127 |
+
A_val *= kLog2e;
|
| 128 |
+
weight_t B_vals[kNItems], C_vals[kNItems];
|
| 129 |
+
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
| 130 |
+
smem_load_weight, (params.seqlen - chunk * kChunkSize));
|
| 131 |
+
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
| 132 |
+
smem_load_weight1, (params.seqlen - chunk * kChunkSize));
|
| 133 |
+
__syncthreads();
|
| 134 |
+
scan_t thread_data[kNItems];
|
| 135 |
+
#pragma unroll
|
| 136 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 137 |
+
thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]);
|
| 138 |
+
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
| 139 |
+
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
| 140 |
+
thread_data[i] = make_float2(1.f, 0.f);
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
// Initialize running total
|
| 145 |
+
scan_t running_prefix;
|
| 146 |
+
// If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
|
| 147 |
+
running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
| 148 |
+
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
| 149 |
+
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
| 150 |
+
Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
| 151 |
+
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
| 152 |
+
);
|
| 153 |
+
// There's a syncthreads in the scan op, so we don't need to sync here.
|
| 154 |
+
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
|
| 155 |
+
if (threadIdx.x == 0) {
|
| 156 |
+
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
| 157 |
+
x[chunk * params.dstate + state_idx] = prefix_op.running_prefix;
|
| 158 |
+
}
|
| 159 |
+
#pragma unroll
|
| 160 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 161 |
+
out_vals[i] += thread_data[i].y * C_vals[i];
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
| 166 |
+
+ dim_id * params.out_d_stride + chunk * kChunkSize;
|
| 167 |
+
__syncthreads();
|
| 168 |
+
store_output<Ktraits>(out, out_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
| 169 |
+
Bvar += kChunkSize;
|
| 170 |
+
Cvar += kChunkSize;
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
| 175 |
+
void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
| 176 |
+
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
| 177 |
+
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, input_t, weight_t>;
|
| 178 |
+
constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t);
|
| 179 |
+
// printf("smem_size = %d\n", kSmemSize);
|
| 180 |
+
dim3 grid(params.batch, params.dim);
|
| 181 |
+
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
| 182 |
+
if (kSmemSize >= 48 * 1024) {
|
| 183 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 184 |
+
}
|
| 185 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 186 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 187 |
+
});
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
template<int knrows, typename input_t, typename weight_t>
|
| 191 |
+
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
| 192 |
+
if (params.seqlen <= 128) {
|
| 193 |
+
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
|
| 194 |
+
} else if (params.seqlen <= 256) {
|
| 195 |
+
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
|
| 196 |
+
} else if (params.seqlen <= 512) {
|
| 197 |
+
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
|
| 198 |
+
} else if (params.seqlen <= 1024) {
|
| 199 |
+
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
| 200 |
+
} else {
|
| 201 |
+
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
| 202 |
+
}
|
| 203 |
+
}
|
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_bwd_kernel_ndstate.cuh
ADDED
@@ -0,0 +1,302 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include <ATen/cuda/Atomic.cuh>  // For atomicAdd on complex

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_reduce.cuh>

#include "selective_scan_ndstate.h"
#include "selective_scan_common.h"
#include "reverse_scan.cuh"
#include "static_switch.h"

template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kDeltaSoftplus_, typename input_t_, typename weight_t_>
struct Selective_Scan_bwd_kernel_traits {
    static_assert(kNItems_ % 4 == 0);
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kNItems = kNItems_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
    static_assert(kNItems % kNElts == 0);
    static constexpr int kNLoads = kNItems / kNElts;
    static constexpr bool kIsEvenLen = kIsEvenLen_;
    static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
    // For complex this would lead to massive register spilling, so we keep it at 2.
    static constexpr int kMinBlocks = kNThreads == 128 && 3;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using scan_t = float2;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
    using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
    using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
    using BlockExchangeT = cub::BlockExchange<float, kNThreads, kNItems>;
    static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
                                                 sizeof(typename BlockLoadVecT::TempStorage),
                                                 2 * sizeof(typename BlockLoadWeightT::TempStorage),
                                                 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                 sizeof(typename BlockStoreT::TempStorage),
                                                 sizeof(typename BlockStoreVecT::TempStorage)});
    static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage);
    static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
};

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_bwd_kernel(SSMParamsBwd params) {
    constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNItems = Ktraits::kNItems;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
    auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
    auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
    auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
    auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
    weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
    scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 + kNThreads);
    weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + 1);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int group_id = dim_id / (params.dim_ngroups_ratio);
    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
        + dim_id * params.u_d_stride;
    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
        + dim_id * params.delta_d_stride;
    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
        + dim_id * params.dout_d_stride;
    weight_t A_val = reinterpret_cast<weight_t *>(params.A_ptr)[dim_id];
    constexpr float kLog2e = M_LOG2E;
    weight_t A_scaled = A_val * kLog2e;
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
    weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id;
    weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
        + (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride);
    weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
        + (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride);
    float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
    float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
    float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
    float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
    scan_t *x = params.x_ptr == nullptr
        ? nullptr
        : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks);
    float dD_val = 0;
    float ddelta_bias_val = 0;

    constexpr int kChunkSize = kNThreads * kNItems;
    u += (params.n_chunks - 1) * kChunkSize;
    delta += (params.n_chunks - 1) * kChunkSize;
    dout += (params.n_chunks - 1) * kChunkSize;
    Bvar += (params.n_chunks - 1) * kChunkSize;
    Cvar += (params.n_chunks - 1) * kChunkSize;
    for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
        input_t u_vals[kNItems];
        input_t delta_vals_load[kNItems];
        input_t dout_vals_load[kNItems];
        __syncthreads();
        load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
        __syncthreads();
        load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
        __syncthreads();
        load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
        u -= kChunkSize;
        // Will reload delta at the same location if kDeltaSoftplus
        if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
        dout -= kChunkSize;

        float dout_vals[kNItems], delta_vals[kNItems];
        float du_vals[kNItems];
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) {
            dout_vals[i] = float(dout_vals_load[i]);
            delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
            if constexpr (kDeltaSoftplus) {
                delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
            }
        }
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }

        float ddelta_vals[kNItems] = {0};
        __syncthreads();
        {
            weight_t B_vals[kNItems], C_vals[kNItems];
            load_weight<Ktraits>(Bvar, B_vals,
                smem_load_weight, (params.seqlen - chunk * kChunkSize));
            auto &smem_load_weight_C = smem_load_weight1;
            load_weight<Ktraits>(Cvar, C_vals,
                smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
            scan_t thread_data[kNItems], thread_reverse_data[kNItems];
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
                thread_data[i] = make_float2(delta_a_exp, delta_vals[i] * float(u_vals[i]) * B_vals[i]);
                if (i == 0) {
                    smem_delta_a[threadIdx.x == 0 ? (chunk % 2): threadIdx.x + 2] = delta_a_exp;
                } else {
                    thread_reverse_data[i - 1].x = delta_a_exp;
                }
                thread_reverse_data[i].y = dout_vals[i] * C_vals[i];
            }
            __syncthreads();
            thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
                ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[(chunk + 1) % 2])
                : smem_delta_a[threadIdx.x + 1 + 2];
            // Initialize running total
            scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[chunk - 1] : make_float2(1.f, 0.f);
            SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
            Ktraits::BlockScanT(smem_scan).InclusiveScan(
                thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
            );
            scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[0] : make_float2(1.f, 0.f);
            SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
            Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
                thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
            );
            if (threadIdx.x == 0) { smem_running_postfix[0] = postfix_op.running_prefix; }
            weight_t dA_val = 0;
            weight_t dB_vals[kNItems], dC_vals[kNItems];
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                const float dx = thread_reverse_data[i].y;
                const float ddelta_u = dx * B_vals[i];
                du_vals[i] += ddelta_u * delta_vals[i];
                const float a = thread_data[i].y - (delta_vals[i] * float(u_vals[i]) * B_vals[i]);
                ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
                dA_val += dx * delta_vals[i] * a;
                dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]);
                dC_vals[i] = dout_vals[i] * thread_data[i].y;
            }
            // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
            Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
            auto &smem_exchange_C = smem_exchange1;
            Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
            const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
            weight_t *dB_cur = dB + chunk * kChunkSize + threadIdx.x;
            weight_t *dC_cur = dC + chunk * kChunkSize + threadIdx.x;
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                if (i * kNThreads < seqlen_remaining) {
                    { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
                    { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
                }
            }
            dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
            if (threadIdx.x == 0) {
                smem_da[0] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[0];
            }
        }

        if constexpr (kDeltaSoftplus) {
            input_t delta_vals_load[kNItems];
            __syncthreads();
            load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
            delta -= kChunkSize;
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float delta_val = float(delta_vals_load[i]) + delta_bias;
                float delta_val_neg_exp = expf(-delta_val);
                ddelta_vals[i] = delta_val <= 20.f
                    ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
                    : ddelta_vals[i];
            }
        }

        __syncthreads();
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }

        input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
            + dim_id * params.du_d_stride + chunk * kChunkSize;
        input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
            + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
        __syncthreads();
        store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
        __syncthreads();
        store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
        Bvar -= kChunkSize;
        Cvar -= kChunkSize;
    }

    if (params.dD_ptr != nullptr) {
        __syncthreads();
        dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
        if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
    }
    if (params.ddelta_bias_ptr != nullptr) {
        __syncthreads();
        ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
        if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
    }
    __syncthreads();
    if (threadIdx.x == 0) { gpuAtomicAdd(dA, smem_da[0]); }
}

template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
            using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kDeltaSoftplus, input_t, weight_t>;
            constexpr int kSmemSize = Ktraits::kSmemSize + sizeof(typename Ktraits::scan_t) + (kNThreads + 4) * sizeof(typename Ktraits::weight_t);
            // printf("smem_size = %d\n", kSmemSize);
            dim3 grid(params.batch, params.dim);
            auto kernel = &selective_scan_bwd_kernel<Ktraits>;
            if (kSmemSize >= 48 * 1024) {
                C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            }
            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    });
}

template<int knrows, typename input_t, typename weight_t>
void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
    if (params.seqlen <= 128) {
        selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 256) {
        selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 1024) {
        selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
    } else {
        selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
    }
}
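For reference, the delta-softplus correction in the backward kernel above can be isolated in a standalone sketch. It is not part of the uploaded files; when delta goes through softplus(x) = log(1 + e^x) in the forward pass, the chain rule multiplies the incoming gradient by d softplus / dx = 1 / (1 + e^-x), which is exactly the ddelta_vals[i] / (1.f + delta_val_neg_exp) update in the kernel (skipped for x > 20, where softplus is effectively the identity).

    // A minimal sketch of the softplus backward factor used above, assuming the same 20.f cutoff.
    #include <cmath>

    static float softplus_backward(float ddelta, float delta_raw_plus_bias) {
        if (delta_raw_plus_bias <= 20.f) {
            // d/dx log(1 + e^x) = sigmoid(x)
            return ddelta / (1.f + std::exp(-delta_raw_plus_bias));
        }
        return ddelta;  // softplus(x) ~ x for large x, so the derivative is ~ 1
    }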
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_bwd_kernel_ndstate.cuh"

template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);

rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_fwd_kernel_ndstate.cuh"

template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);

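The two core_*.cu files above exist only to instantiate the templated launchers once per supported input dtype, so the heavy kernel headers compile in their own translation units. A generic sketch of the same explicit-instantiation pattern (not from the uploaded files; the function and file names are made up):

    // scale_op.cpp (hypothetical): a function template plus explicit instantiations,
    // mirroring the "template void selective_scan_*_cuda<...>(...);" lines above.
    template<typename T>
    void scale_on_host(T *data, int n, T factor) {
        for (int i = 0; i < n; ++i) { data[i] *= factor; }
    }

    // One explicit instantiation per supported type; other translation units can then
    // declare and call these without re-instantiating the template body.
    template void scale_on_host<float>(float *, int, float);
    template void scale_on_host<double>(double *, int, double);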
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_fwd_kernel_ndstate.cuh
ADDED
@@ -0,0 +1,200 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>

#include "selective_scan_ndstate.h"
#include "selective_scan_common.h"
#include "static_switch.h"

template<int kNThreads_, int kNItems_, bool kIsEvenLen_, typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits {
    static_assert(kNItems_ % 4 == 0);
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
    static constexpr int kNItems = kNItems_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
    static_assert(kNItems % kNElts == 0);
    static constexpr int kNLoads = kNItems / kNElts;
    static constexpr bool kIsEvenLen = kIsEvenLen_;

    static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;

    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using scan_t = float2;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
                                                 sizeof(typename BlockLoadVecT::TempStorage),
                                                 2 * sizeof(typename BlockLoadWeightT::TempStorage),
                                                 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                 sizeof(typename BlockStoreT::TempStorage),
                                                 sizeof(typename BlockStoreVecT::TempStorage)});
    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
};

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_fwd_kernel(SSMParamsBase params) {
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNItems = Ktraits::kNItems;
    constexpr bool kDirectIO = Ktraits::kDirectIO;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int group_id = dim_id / (params.dim_ngroups_ratio);
    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
        + dim_id * params.u_d_stride;
    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
        + dim_id * params.delta_d_stride;
    constexpr float kLog2e = M_LOG2E;
    weight_t A_val = reinterpret_cast<weight_t *>(params.A_ptr)[dim_id] * kLog2e;
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
    scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks;

    float D_val = 0; // attention!
    if (params.D_ptr != nullptr) {
        D_val = reinterpret_cast<float *>(params.D_ptr)[dim_id];
    }
    float delta_bias = 0;
    if (params.delta_bias_ptr != nullptr) {
        delta_bias = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
    }

    constexpr int kChunkSize = kNThreads * kNItems;
    for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
        input_t u_vals[kNItems], delta_vals_load[kNItems];
        __syncthreads();
        load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
        if constexpr (!kDirectIO) { __syncthreads(); }
        load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
        u += kChunkSize;
        delta += kChunkSize;

        float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems];
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) {
            float u_val = float(u_vals[i]);
            delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
            if (params.delta_softplus) {
                delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
            }
            delta_u_vals[i] = delta_vals[i] * u_val;
            out_vals[i] = D_val * u_val;
        }

        __syncthreads();
        {
            weight_t B_vals[kNItems], C_vals[kNItems];
            load_weight<Ktraits>(Bvar, B_vals,
                smem_load_weight, (params.seqlen - chunk * kChunkSize));
            auto &smem_load_weight_C = smem_load_weight1;
            load_weight<Ktraits>(Cvar, C_vals,
                smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
            __syncthreads();
            scan_t thread_data[kNItems];
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]);
                if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
                    if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                        thread_data[i] = make_float2(1.f, 0.f);
                    }
                }
            }
            // Initialize running total
            scan_t running_prefix;
            // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
            running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[0] : make_float2(1.f, 0.f);
            SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
            Ktraits::BlockScanT(smem_scan).InclusiveScan(
                thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
            );
            // There's a syncthreads in the scan op, so we don't need to sync here.
            // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
            if (threadIdx.x == 0) {
                smem_running_prefix[0] = prefix_op.running_prefix;
                x[chunk] = prefix_op.running_prefix;
            }
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                out_vals[i] += thread_data[i].y * C_vals[i];
            }
        }

        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
            + dim_id * params.out_d_stride + chunk * kChunkSize;
        __syncthreads();
        store_output<Ktraits>(out, out_vals, smem_store, params.seqlen - chunk * kChunkSize);
        Bvar += kChunkSize;
        Cvar += kChunkSize;
    }
}

template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize + sizeof(typename Ktraits::scan_t);
        // printf("smem_size = %d\n", kSmemSize);
        dim3 grid(params.batch, params.dim);
        auto kernel = &selective_scan_fwd_kernel<Ktraits>;
        if (kSmemSize >= 48 * 1024) {
            C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<int knrows, typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
    if (params.seqlen <= 128) {
        selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 256) {
        selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 1024) {
        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
    } else {
        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
    }
}
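For reference, the recurrence that this single-state ("ndstate" = 1) forward kernel evaluates for one (batch, channel) row can be written as a plain sequential loop. The sketch below is not part of the uploaded files; the kernel computes the same values with a parallel inclusive scan over pairs (a_t, b_t) = (exp(delta_t * A), delta_t * u_t * B_t), carrying the running prefix across 2048-element chunks through smem_running_prefix and the saved x tensor.

    // A sequential reference sketch of the scanned recurrence, assuming float inputs.
    #include <cmath>
    #include <vector>

    std::vector<float> selective_scan_reference(const std::vector<float> &u,
                                                const std::vector<float> &delta,
                                                float A, const std::vector<float> &B,
                                                const std::vector<float> &C, float D) {
        std::vector<float> y(u.size());
        float h = 0.f;                               // one scalar hidden state per channel here
        for (size_t t = 0; t < u.size(); ++t) {
            const float a = std::exp(delta[t] * A);  // decay; the kernel uses exp2f with A pre-scaled by log2(e)
            const float b = delta[t] * u[t] * B[t];  // input injection
            h = a * h + b;                           // the linear recurrence the block scan parallelizes
            y[t] = C[t] * h + D * u[t];              // readout plus the D * u skip term
        }
        return y;
    }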
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp
ADDED
@@ -0,0 +1,341 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <vector>

#include "selective_scan_ndstate.h"
#define MAX_DSTATE 256

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
using weight_t = float;

#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
    if (ITYPE == at::ScalarType::Half) { \
        using input_t = at::Half; \
        __VA_ARGS__(); \
    } else if (ITYPE == at::ScalarType::BFloat16) { \
        using input_t = at::BFloat16; \
        __VA_ARGS__(); \
    } else if (ITYPE == at::ScalarType::Float) { \
        using input_t = float; \
        __VA_ARGS__(); \
    } else { \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }

template<int knrows, typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);

template <int knrows, typename input_t, typename weight_t>
void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);

void set_ssm_params_fwd(SSMParamsBase &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t n_groups,
                        const size_t n_chunks,
                        // device pointers
                        const at::Tensor u,
                        const at::Tensor delta,
                        const at::Tensor A,
                        const at::Tensor B,
                        const at::Tensor C,
                        const at::Tensor out,
                        void* D_ptr,
                        void* delta_bias_ptr,
                        void* x_ptr,
                        bool delta_softplus) {

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.n_groups = n_groups;
    params.n_chunks = n_chunks;
    params.dim_ngroups_ratio = dim / n_groups;

    params.delta_softplus = delta_softplus;

    // Set the pointers and strides.
    params.u_ptr = u.data_ptr();
    params.delta_ptr = delta.data_ptr();
    params.A_ptr = A.data_ptr();
    params.B_ptr = B.data_ptr();
    params.C_ptr = C.data_ptr();
    params.D_ptr = D_ptr;
    params.delta_bias_ptr = delta_bias_ptr;
    params.out_ptr = out.data_ptr();
    params.x_ptr = x_ptr;

    // All stride are in elements, not bytes.
    params.A_d_stride = A.stride(0);
    params.B_batch_stride = B.stride(0);
    params.B_group_stride = B.stride(1);
    params.C_batch_stride = C.stride(0);
    params.C_group_stride = C.stride(1);
    params.u_batch_stride = u.stride(0);
    params.u_d_stride = u.stride(1);
    params.delta_batch_stride = delta.stride(0);
    params.delta_d_stride = delta.stride(1);

    params.out_batch_stride = out.stride(0);
    params.out_d_stride = out.stride(1);
}

void set_ssm_params_bwd(SSMParamsBwd &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t n_groups,
                        const size_t n_chunks,
                        // device pointers
                        const at::Tensor u,
                        const at::Tensor delta,
                        const at::Tensor A,
                        const at::Tensor B,
                        const at::Tensor C,
                        const at::Tensor out,
                        void* D_ptr,
                        void* delta_bias_ptr,
                        void* x_ptr,
                        const at::Tensor dout,
                        const at::Tensor du,
                        const at::Tensor ddelta,
                        const at::Tensor dA,
                        const at::Tensor dB,
                        const at::Tensor dC,
                        void* dD_ptr,
                        void* ddelta_bias_ptr,
                        bool delta_softplus) {
    // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
    set_ssm_params_fwd(params, batch, dim, seqlen, n_groups, n_chunks,
                       u, delta, A, B, C, dout,
                       D_ptr, delta_bias_ptr, x_ptr, delta_softplus);

    // Set the pointers and strides.
    params.dout_ptr = dout.data_ptr();
    params.du_ptr = du.data_ptr();
    params.dA_ptr = dA.data_ptr();
    params.dB_ptr = dB.data_ptr();
    params.dC_ptr = dC.data_ptr();
    params.dD_ptr = dD_ptr;
    params.ddelta_ptr = ddelta.data_ptr();
    params.ddelta_bias_ptr = ddelta_bias_ptr;
    // All stride are in elements, not bytes.
    params.dout_batch_stride = dout.stride(0);
    params.dout_d_stride = dout.stride(1);
    params.dA_d_stride = dA.stride(0);
    params.dB_batch_stride = dB.stride(0);
    params.dB_group_stride = dB.stride(1);
    params.dC_batch_stride = dC.stride(0);
    params.dC_group_stride = dC.stride(1);
    params.du_batch_stride = du.stride(0);
    params.du_d_stride = du.stride(1);
    params.ddelta_batch_stride = ddelta.stride(0);
    params.ddelta_d_stride = ddelta.stride(1);

}

std::vector<at::Tensor>
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
                   const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
                   const c10::optional<at::Tensor> &D_,
                   const c10::optional<at::Tensor> &delta_bias_,
                   bool delta_softplus,
                   int nrows
                   ) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float);

    TORCH_CHECK(delta.scalar_type() == input_type);
    TORCH_CHECK(B.scalar_type() == input_type);
    TORCH_CHECK(C.scalar_type() == input_type);

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());

    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

    const auto sizes = u.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int n_groups = B.size(1);

    TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups");

    CHECK_SHAPE(u, batch_size, dim, seqlen);
    CHECK_SHAPE(delta, batch_size, dim, seqlen);
    CHECK_SHAPE(A, dim);
    CHECK_SHAPE(B, batch_size, n_groups, seqlen);
    TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
    CHECK_SHAPE(C, batch_size, n_groups, seqlen);
    TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }

    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }

    const int n_chunks = (seqlen + 2048 - 1) / 2048;  // max is 128 * 16 = 2048 in fwd_kernel
    at::Tensor out = torch::empty_like(delta);
    at::Tensor x;
    x = torch::empty({batch_size, dim, n_chunks, 1 * 2}, u.options().dtype(weight_type));

    SSMParamsBase params;
    set_ssm_params_fwd(params, batch_size, dim, seqlen, n_groups, n_chunks,
                       u, delta, A, B, C, out,
                       D_.has_value() ? D_.value().data_ptr() : nullptr,
                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
                       x.data_ptr(),
                       delta_softplus);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)u.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
        selective_scan_fwd_cuda<1, input_t, weight_t>(params, stream);
    });
    std::vector<at::Tensor> result = {out, x};
    return result;
}

std::vector<at::Tensor>
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
                   const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
                   const c10::optional<at::Tensor> &D_,
                   const c10::optional<at::Tensor> &delta_bias_,
                   const at::Tensor &dout,
                   const c10::optional<at::Tensor> &x_,
                   bool delta_softplus,
                   int nrows
                   ) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float);

    TORCH_CHECK(delta.scalar_type() == input_type);
    TORCH_CHECK(B.scalar_type() == input_type);
    TORCH_CHECK(C.scalar_type() == input_type);
    TORCH_CHECK(dout.scalar_type() == input_type);

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());
    TORCH_CHECK(dout.is_cuda());

    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
    TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);

    const auto sizes = u.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int n_groups = B.size(1);

    TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups");

    CHECK_SHAPE(u, batch_size, dim, seqlen);
    CHECK_SHAPE(delta, batch_size, dim, seqlen);
    CHECK_SHAPE(A, dim);
    CHECK_SHAPE(B, batch_size, n_groups, seqlen);
    TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
    CHECK_SHAPE(C, batch_size, n_groups, seqlen);
    TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
    CHECK_SHAPE(dout, batch_size, dim, seqlen);

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }

    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }

    at::Tensor out;
    const int n_chunks = (seqlen + 2048 - 1) / 2048;
    // const int n_chunks = (seqlen + 1024 - 1) / 1024;
    if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
    if (x_.has_value()) {
        auto x = x_.value();
        TORCH_CHECK(x.scalar_type() == weight_type);
        TORCH_CHECK(x.is_cuda());
        TORCH_CHECK(x.is_contiguous());
        CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * 1);
    }

    at::Tensor du = torch::empty_like(u);
    at::Tensor ddelta = torch::empty_like(delta);
    at::Tensor dA = torch::zeros_like(A);
    at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32));
    at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32));
    at::Tensor dD;
    if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
    at::Tensor ddelta_bias;
    if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }

    SSMParamsBwd params;
    set_ssm_params_bwd(params, batch_size, dim, seqlen, n_groups, n_chunks,
                       u, delta, A, B, C, out,
                       D_.has_value() ? D_.value().data_ptr() : nullptr,
                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
                       x_.has_value() ? x_.value().data_ptr() : nullptr,
                       dout, du, ddelta, dA, dB, dC,
                       D_.has_value() ? dD.data_ptr() : nullptr,
                       delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
                       delta_softplus);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)u.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
        selective_scan_bwd_cuda<1, input_t, weight_t>(params, stream);
    });
    std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
    return result;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fwd", &selective_scan_fwd, "Selective scan forward");
    m.def("bwd", &selective_scan_bwd, "Selective scan backward");
}
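For reference, the bookkeeping behind the "x" tensor allocated above can be made explicit in a small sketch (not part of the uploaded files). The forward pass stores one float2 running prefix (decay product, state) per chunk for every (batch, dim) row, which is why x is created as {batch, dim, n_chunks, 2} floats and the backward pass checks the same shape before consuming it.

    // A minimal sketch of the saved-state buffer sizing, assuming the 2048-element chunk
    // used by the largest launch (128 threads * 16 items).
    #include <cstdio>

    int main() {
        const int batch = 4, dim = 96, seqlen = 6000;
        const int chunk = 2048;
        const int n_chunks = (seqlen + chunk - 1) / chunk;            // same rounding as the binding
        const long long x_elems = 1LL * batch * dim * n_chunks * 2;   // 2 floats per saved scan_t (float2)
        std::printf("n_chunks=%d, x buffer holds %lld floats\n", n_chunks, x_elems);
        return 0;
    }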
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.h
ADDED
@@ -0,0 +1,84 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SSMScanParamsBase {
    using index_t = uint32_t;

    int batch, seqlen, n_chunks;
    index_t a_batch_stride;
    index_t b_batch_stride;
    index_t out_batch_stride;

    // Common data pointers.
    void *__restrict__ a_ptr;
    void *__restrict__ b_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SSMParamsBase {
    using index_t = uint32_t;

    int batch, dim, seqlen, n_groups, n_chunks;
    int dim_ngroups_ratio;

    bool delta_softplus;

    index_t A_d_stride;
    index_t B_batch_stride;
    index_t B_d_stride;
    index_t B_group_stride;
    index_t C_batch_stride;
    index_t C_d_stride;
    index_t C_group_stride;
    index_t u_batch_stride;
    index_t u_d_stride;
    index_t delta_batch_stride;
    index_t delta_d_stride;
    index_t out_batch_stride;
    index_t out_d_stride;

    // Common data pointers.
    void *__restrict__ A_ptr;
    void *__restrict__ B_ptr;
    void *__restrict__ C_ptr;
    void *__restrict__ D_ptr;
    void *__restrict__ u_ptr;
    void *__restrict__ delta_ptr;
    void *__restrict__ delta_bias_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
};

struct SSMParamsBwd: public SSMParamsBase {
    index_t dout_batch_stride;
    index_t dout_d_stride;
    index_t dA_d_stride;
    index_t dB_batch_stride;
    index_t dB_group_stride;
    index_t dB_d_stride;
    index_t dC_batch_stride;
    index_t dC_group_stride;
    index_t dC_d_stride;
    index_t du_batch_stride;
    index_t du_d_stride;
    index_t ddelta_batch_stride;
    index_t ddelta_d_stride;

    // Common data pointers.
    void *__restrict__ dout_ptr;
    void *__restrict__ dA_ptr;
    void *__restrict__ dB_ptr;
    void *__restrict__ dC_ptr;
    void *__restrict__ dD_ptr;
    void *__restrict__ du_ptr;
    void *__restrict__ ddelta_ptr;
    void *__restrict__ ddelta_bias_ptr;
};
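For reference, the stride fields in SSMParamsBase above are counted in elements, so for a contiguous (batch, dim, seqlen) tensor element (b, d, t) lives at b * batch_stride + d * d_stride + t. The sketch below is not part of the uploaded files; the tiny Strides struct is a stand-in for the real params struct.

    // A minimal sketch of element-offset arithmetic under the stride convention above.
    #include <cassert>
    #include <cstdint>

    struct Strides { uint32_t batch_stride; uint32_t d_stride; };

    int main() {
        const int dim = 8, seqlen = 16;
        // For a contiguous tensor of shape (batch, dim, seqlen):
        Strides u{static_cast<uint32_t>(dim * seqlen), static_cast<uint32_t>(seqlen)};
        // Offset of element (b=1, d=2, t=3), matching how the kernels index u_ptr.
        const uint32_t offset = 1 * u.batch_stride + 2 * u.d_stride + 3;
        assert(offset == 1u * dim * seqlen + 2u * seqlen + 3u);
        return 0;
    }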
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_bwd_kernel_nrow.cuh
ADDED
@@ -0,0 +1,344 @@
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <c10/util/BFloat16.h>
|
| 8 |
+
#include <c10/util/Half.h>
|
| 9 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 10 |
+
#include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
|
| 11 |
+
|
| 12 |
+
#include <cub/block/block_load.cuh>
|
| 13 |
+
#include <cub/block/block_store.cuh>
|
| 14 |
+
#include <cub/block/block_scan.cuh>
|
| 15 |
+
#include <cub/block/block_reduce.cuh>
|
| 16 |
+
|
| 17 |
+
#include "selective_scan.h"
|
| 18 |
+
#include "selective_scan_common.h"
|
| 19 |
+
#include "reverse_scan.cuh"
|
| 20 |
+
#include "static_switch.h"
|
| 21 |
+
|
| 22 |
+
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_, bool kDeltaSoftplus_, typename input_t_, typename weight_t_>
|
| 23 |
+
struct Selective_Scan_bwd_kernel_traits {
|
| 24 |
+
static_assert(kNItems_ % 4 == 0);
|
| 25 |
+
using input_t = input_t_;
|
| 26 |
+
using weight_t = weight_t_;
|
| 27 |
+
static constexpr int kNThreads = kNThreads_;
|
| 28 |
+
static constexpr int kNItems = kNItems_;
|
| 29 |
+
static constexpr int kNRows = kNRows_;
|
| 30 |
+
static constexpr int MaxDState = MAX_DSTATE / kNRows_;
|
| 31 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 32 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 33 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
|
| 34 |
+
static_assert(kNItems % kNElts == 0);
|
| 35 |
+
static constexpr int kNLoads = kNItems / kNElts;
|
| 36 |
+
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
| 37 |
+
static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
|
| 38 |
+
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
|
| 39 |
+
// For complex this would lead to massive register spilling, so we keep it at 2.
|
| 40 |
+
static constexpr int kMinBlocks = kNThreads == 128 && 3;
|
| 41 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 42 |
+
using scan_t = float2;
|
| 43 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 44 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 45 |
+
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 46 |
+
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 47 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 48 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 49 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
| 50 |
+
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
| 51 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
| 52 |
+
using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
|
| 53 |
+
using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
|
| 54 |
+
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
|
| 55 |
+
using BlockExchangeT = cub::BlockExchange<float, kNThreads, kNItems>;
|
| 56 |
+
static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
| 57 |
+
sizeof(typename BlockLoadVecT::TempStorage),
|
| 58 |
+
2 * sizeof(typename BlockLoadWeightT::TempStorage),
|
| 59 |
+
2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
| 60 |
+
sizeof(typename BlockStoreT::TempStorage),
|
| 61 |
+
sizeof(typename BlockStoreVecT::TempStorage)});
|
| 62 |
+
static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage);
|
| 63 |
+
static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
|
| 64 |
+
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
template<typename Ktraits>
|
| 68 |
+
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
| 69 |
+
void selective_scan_bwd_kernel(SSMParamsBwd params) {
|
| 70 |
+
constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
|
| 71 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 72 |
+
constexpr int kNItems = Ktraits::kNItems;
|
| 73 |
+
constexpr int kNRows = Ktraits::kNRows;
|
| 74 |
+
using input_t = typename Ktraits::input_t;
|
| 75 |
+
using weight_t = typename Ktraits::weight_t;
|
| 76 |
+
using scan_t = typename Ktraits::scan_t;
|
| 77 |
+
|
| 78 |
+
// Shared memory.
|
| 79 |
+
extern __shared__ char smem_[];
|
| 80 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
| 81 |
+
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
| 82 |
+
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
| 83 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
| 84 |
+
auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
| 85 |
+
auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
|
| 86 |
+
auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
|
| 87 |
+
auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
|
| 88 |
+
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
|
| 89 |
+
auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
|
| 90 |
+
weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
|
| 91 |
+
// scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + kNRows * (2 * Ktraits::MaxDState + kNThreads));
|
| 92 |
+
scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + kNRows * 2 * Ktraits::MaxDState + kNThreads);
|
| 93 |
+
weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + kNRows * Ktraits::MaxDState);
|
| 94 |
+
|
| 95 |
+
const int batch_id = blockIdx.x;
|
| 96 |
+
const int dim_id = blockIdx.y;
|
| 97 |
+
const int dim_id_nrow = dim_id * kNRows;
|
| 98 |
+
const int group_id = dim_id_nrow / (params.dim_ngroups_ratio);
|
| 99 |
+
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
| 100 |
+
+ dim_id_nrow * params.u_d_stride;
|
| 101 |
+
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
| 102 |
+
+ dim_id_nrow * params.delta_d_stride;
|
| 103 |
+
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
| 104 |
+
+ dim_id_nrow * params.dout_d_stride;
|
| 105 |
+
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id_nrow * params.A_d_stride;
|
| 106 |
+
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
| 107 |
+
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
| 108 |
+
weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id_nrow * params.dA_d_stride;
|
| 109 |
+
weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
|
| 110 |
+
+ (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride);
|
| 111 |
+
weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
|
| 112 |
+
+ (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride);
|
| 113 |
+
float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id_nrow;
|
| 114 |
+
float *D_val = params.D_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.D_ptr) + dim_id_nrow;
|
| 115 |
+
float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id_nrow;
|
| 116 |
+
float *delta_bias = params.delta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.delta_bias_ptr) + dim_id_nrow;
|
| 117 |
+
scan_t *x = params.x_ptr == nullptr
|
| 118 |
+
? nullptr
|
| 119 |
+
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id_nrow) * (params.n_chunks) * params.dstate;
|
| 120 |
+
float dD_val[kNRows] = {0};
|
| 121 |
+
float ddelta_bias_val[kNRows] = {0};
|
| 122 |
+
|
| 123 |
+
constexpr int kChunkSize = kNThreads * kNItems;
|
| 124 |
+
u += (params.n_chunks - 1) * kChunkSize;
|
| 125 |
+
delta += (params.n_chunks - 1) * kChunkSize;
|
| 126 |
+
dout += (params.n_chunks - 1) * kChunkSize;
|
| 127 |
+
Bvar += (params.n_chunks - 1) * kChunkSize;
|
| 128 |
+
Cvar += (params.n_chunks - 1) * kChunkSize;
|
| 129 |
+
for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
|
| 130 |
+
input_t u_vals[kNRows][kNItems];
|
| 131 |
+
input_t delta_vals_load[kNRows][kNItems];
|
| 132 |
+
input_t dout_vals_load[kNRows][kNItems];
|
| 133 |
+
#pragma unroll
|
| 134 |
+
for (int r = 0; r < kNRows; ++r) {
|
| 135 |
+
__syncthreads();
|
| 136 |
+
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
|
| 137 |
+
__syncthreads();
|
| 138 |
+
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
| 139 |
+
__syncthreads();
|
| 140 |
+
load_input<Ktraits>(dout + r * params.dout_d_stride, dout_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
| 141 |
+
}
|
| 142 |
+
u -= kChunkSize;
|
| 143 |
+
// Will reload delta at the same location if kDeltaSoftplus
|
| 144 |
+
if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
|
| 145 |
+
dout -= kChunkSize;
|
| 146 |
+
|
| 147 |
+
float dout_vals[kNRows][kNItems], delta_vals[kNRows][kNItems];
|
| 148 |
+
float du_vals[kNRows][kNItems];
|
| 149 |
+
#pragma unroll
|
| 150 |
+
for (int r = 0; r < kNRows; ++r) {
|
| 151 |
+
#pragma unroll
|
| 152 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 153 |
+
dout_vals[r][i] = float(dout_vals_load[r][i]);
|
| 154 |
+
delta_vals[r][i] = float(delta_vals_load[r][i]) + (delta_bias == nullptr? 0: delta_bias[r]);
|
| 155 |
+
if constexpr (kDeltaSoftplus) {
|
| 156 |
+
delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
|
| 157 |
+
}
|
| 158 |
+
}
|
| 159 |
+
#pragma unroll
|
| 160 |
+
for (int i = 0; i < kNItems; ++i) { du_vals[r][i] = (D_val == nullptr? 0: D_val[r]) * dout_vals[r][i]; }
|
| 161 |
+
#pragma unroll
|
| 162 |
+
for (int i = 0; i < kNItems; ++i) { dD_val[r] += dout_vals[r][i] * float(u_vals[r][i]); }
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
float ddelta_vals[kNRows][kNItems] = {0};
|
| 166 |
+
__syncthreads();
|
| 167 |
+
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
| 168 |
+
weight_t A_val[kNRows];
|
| 169 |
+
weight_t A_scaled[kNRows];
|
| 170 |
+
#pragma unroll
|
| 171 |
+
for (int r = 0; r < kNRows; ++r) {
|
| 172 |
+
A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
|
| 173 |
+
constexpr float kLog2e = M_LOG2E;
|
| 174 |
+
A_scaled[r] = A_val[r] * kLog2e;
|
| 175 |
+
}
|
| 176 |
+
weight_t B_vals[kNItems], C_vals[kNItems];
|
| 177 |
+
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
| 178 |
+
smem_load_weight, (params.seqlen - chunk * kChunkSize));
|
| 179 |
+
auto &smem_load_weight_C = smem_load_weight1;
|
| 180 |
+
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
| 181 |
+
smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
|
| 182 |
+
#pragma unroll
|
| 183 |
+
for (int r = 0; r < kNRows; ++r) {
|
| 184 |
+
scan_t thread_data[kNItems], thread_reverse_data[kNItems];
|
| 185 |
+
#pragma unroll
|
| 186 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 187 |
+
const float delta_a_exp = exp2f(delta_vals[r][i] * A_scaled[r]);
|
| 188 |
+
thread_data[i] = make_float2(delta_a_exp, delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]);
|
| 189 |
+
if (i == 0) {
|
| 190 |
+
// smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads) : threadIdx.x + 2 * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads)] = delta_a_exp;
|
| 191 |
+
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState : threadIdx.x + kNRows * 2 * Ktraits::MaxDState] = delta_a_exp;
|
| 192 |
+
|
| 193 |
+
} else {
|
| 194 |
+
thread_reverse_data[i - 1].x = delta_a_exp;
|
| 195 |
+
}
|
| 196 |
+
thread_reverse_data[i].y = dout_vals[r][i] * C_vals[i];
|
| 197 |
+
}
|
| 198 |
+
__syncthreads();
|
| 199 |
+
thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
|
| 200 |
+
// ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads)])
|
| 201 |
+
// : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads)];
|
| 202 |
+
? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState])
|
| 203 |
+
: smem_delta_a[threadIdx.x + 1 + kNRows * 2 * Ktraits::MaxDState];
|
| 204 |
+
// Initialize running total
|
| 205 |
+
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(r * params.n_chunks + chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
|
| 206 |
+
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
| 207 |
+
Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
| 208 |
+
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
| 209 |
+
);
|
| 210 |
+
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f);
|
| 211 |
+
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
|
| 212 |
+
Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
|
| 213 |
+
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
|
| 214 |
+
);
|
| 215 |
+
if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; }
|
| 216 |
+
weight_t dA_val = 0;
|
| 217 |
+
weight_t dB_vals[kNItems], dC_vals[kNItems];
|
| 218 |
+
#pragma unroll
|
| 219 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 220 |
+
const float dx = thread_reverse_data[i].y;
|
| 221 |
+
const float ddelta_u = dx * B_vals[i];
|
| 222 |
+
du_vals[r][i] += ddelta_u * delta_vals[r][i];
|
| 223 |
+
const float a = thread_data[i].y - (delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]);
|
| 224 |
+
ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + dx * A_val[r] * a;
|
| 225 |
+
dA_val += dx * delta_vals[r][i] * a;
|
| 226 |
+
dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]);
|
| 227 |
+
dC_vals[i] = dout_vals[r][i] * thread_data[i].y;
|
| 228 |
+
}
|
| 229 |
+
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
|
| 230 |
+
Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
|
| 231 |
+
auto &smem_exchange_C = smem_exchange1;
|
| 232 |
+
Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
|
| 233 |
+
const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
|
| 234 |
+
weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
| 235 |
+
weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
| 236 |
+
#pragma unroll
|
| 237 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 238 |
+
if (i * kNThreads < seqlen_remaining) {
|
| 239 |
+
{ gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
|
| 240 |
+
{ gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
|
| 244 |
+
if (threadIdx.x == 0) {
|
| 245 |
+
smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState];
|
| 246 |
+
}
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
if constexpr (kDeltaSoftplus) {
|
| 251 |
+
input_t delta_vals_load[kNRows][kNItems];
|
| 252 |
+
#pragma unroll
|
| 253 |
+
for (int r = 0; r < kNRows; ++r) {
|
| 254 |
+
__syncthreads();
|
| 255 |
+
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
| 256 |
+
}
|
| 257 |
+
delta -= kChunkSize;
|
| 258 |
+
#pragma unroll
|
| 259 |
+
for (int r = 0; r < kNRows; ++r) {
|
| 260 |
+
#pragma unroll
|
| 261 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 262 |
+
float delta_val = float(delta_vals_load[r][i]) + (delta_bias == nullptr? 0: delta_bias[r]);
|
| 263 |
+
float delta_val_neg_exp = expf(-delta_val);
|
| 264 |
+
ddelta_vals[r][i] = delta_val <= 20.f
|
| 265 |
+
? ddelta_vals[r][i] / (1.f + delta_val_neg_exp)
|
| 266 |
+
: ddelta_vals[r][i];
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
__syncthreads();
|
| 272 |
+
#pragma unroll
|
| 273 |
+
for (int r = 0; r < kNRows; ++r) {
|
| 274 |
+
#pragma unroll
|
| 275 |
+
for (int i = 0; i < kNItems; ++i) { ddelta_bias_val[r] += ddelta_vals[r][i]; }
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
|
| 279 |
+
+ dim_id_nrow * params.du_d_stride + chunk * kChunkSize;
|
| 280 |
+
input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
|
| 281 |
+
+ dim_id_nrow * params.ddelta_d_stride + chunk * kChunkSize;
|
| 282 |
+
#pragma unroll
|
| 283 |
+
for (int r = 0; r < kNRows; ++r) {
|
| 284 |
+
__syncthreads();
|
| 285 |
+
store_output<Ktraits>(du + r * params.du_d_stride, du_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
| 286 |
+
__syncthreads();
|
| 287 |
+
store_output<Ktraits>(ddelta + r * params.ddelta_d_stride, ddelta_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
Bvar -= kChunkSize;
|
| 291 |
+
Cvar -= kChunkSize;
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
#pragma unroll
|
| 295 |
+
for (int r = 0; r < kNRows; ++r) {
|
| 296 |
+
if (params.dD_ptr != nullptr) {
|
| 297 |
+
__syncthreads();
|
| 298 |
+
dD_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val[r]);
|
| 299 |
+
if (threadIdx.x == 0) { gpuAtomicAdd(&(dD[r]), dD_val[r]); }
|
| 300 |
+
}
|
| 301 |
+
if (params.ddelta_bias_ptr != nullptr) {
|
| 302 |
+
__syncthreads();
|
| 303 |
+
ddelta_bias_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val[r]);
|
| 304 |
+
if (threadIdx.x == 0) { gpuAtomicAdd(&(ddelta_bias[r]), ddelta_bias_val[r]); }
|
| 305 |
+
}
|
| 306 |
+
__syncthreads();
|
| 307 |
+
for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
| 308 |
+
gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride + r * params.dA_d_stride]), smem_da[state_idx + r * Ktraits::MaxDState]);
|
| 309 |
+
}
|
| 310 |
+
}
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
template<int kNThreads, int kNItems, int kNRows, typename input_t, typename weight_t>
|
| 314 |
+
void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
| 315 |
+
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
| 316 |
+
BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
|
| 317 |
+
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kDeltaSoftplus, input_t, weight_t>;
|
| 318 |
+
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * kNRows * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t);
|
| 319 |
+
// printf("smem_size = %d\n", kSmemSize);
|
| 320 |
+
dim3 grid(params.batch, params.dim / kNRows);
|
| 321 |
+
auto kernel = &selective_scan_bwd_kernel<Ktraits>;
|
| 322 |
+
if (kSmemSize >= 48 * 1024) {
|
| 323 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 324 |
+
}
|
| 325 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 326 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 327 |
+
});
|
| 328 |
+
});
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
template<int knrows, typename input_t, typename weight_t>
|
| 332 |
+
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
| 333 |
+
if (params.seqlen <= 128) {
|
| 334 |
+
selective_scan_bwd_launch<32, 4, knrows, input_t, weight_t>(params, stream);
|
| 335 |
+
} else if (params.seqlen <= 256) {
|
| 336 |
+
selective_scan_bwd_launch<32, 8, knrows, input_t, weight_t>(params, stream);
|
| 337 |
+
} else if (params.seqlen <= 512) {
|
| 338 |
+
selective_scan_bwd_launch<32, 16, knrows, input_t, weight_t>(params, stream);
|
| 339 |
+
} else if (params.seqlen <= 1024) {
|
| 340 |
+
selective_scan_bwd_launch<64, 16, knrows, input_t, weight_t>(params, stream);
|
| 341 |
+
} else {
|
| 342 |
+
selective_scan_bwd_launch<128, 16, knrows, input_t, weight_t>(params, stream);
|
| 343 |
+
}
|
| 344 |
+
}
|
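Editor's note: the kDeltaSoftplus branch near the end of the backward kernel above rescales ddelta by the softplus derivative. As a standalone sanity check (a sketch, not code from this repository): softplus(x) = log(1 + e^x) has derivative 1 / (1 + e^(-x)), and both kernels skip the transform for x > 20, where softplus(x) is effectively x and the derivative is effectively 1.

// Sketch of the softplus forward/backward pair the kernels implement inline;
// the 20.f cutoff matches the kernels' numerical guard.
#include <cmath>
#include <cstdio>

static float softplus(float x) {
    return x <= 20.f ? log1pf(expf(x)) : x;                 // as in the fwd kernel
}

static float softplus_bwd(float x, float grad_out) {
    // d/dx log(1 + e^x) = 1 / (1 + e^(-x)); above the cutoff it is ~1.
    return x <= 20.f ? grad_out / (1.f + expf(-x)) : grad_out;
}

int main() {
    const float x = 0.5f;
    printf("softplus(%.2f) = %.6f, grad = %.6f\n", x, softplus(x), softplus_bwd(x, 1.0f));
    return 0;
}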
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_bwd_kernel_nrow.cuh"

template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);

rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_bwd_kernel_nrow.cuh"

template void selective_scan_bwd_cuda<2, float, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<2, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<2, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);

rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu
ADDED
@@ -0,0 +1,8 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_bwd_kernel_nrow.cuh"

template void selective_scan_bwd_cuda<3, float, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<3, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<3, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu
ADDED
@@ -0,0 +1,8 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_bwd_kernel_nrow.cuh"

template void selective_scan_bwd_cuda<4, float, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<4, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<4, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
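Editor's note: the four selective_scan_core_bwd*.cu files above differ only in the knrows template argument; splitting the explicit instantiations across translation units lets the build compile them in parallel. At run time the binding code (selective_scan_nrow.cpp, later in this diff) maps the runtime nrows value onto one of these instantiations with its INT_SWITCH macro; a minimal sketch of that runtime-to-compile-time dispatch, using hypothetical helper names, is:

// Hypothetical sketch of turning a runtime `nrows` into a compile-time template
// argument, mirroring what the INT_SWITCH macro in selective_scan_nrow.cpp does.
#include <type_traits>
#include <utility>

template <typename F>
void nrows_switch(int nrows, F &&f) {
    switch (nrows) {
        case 2:  f(std::integral_constant<int, 2>{}); break;
        case 3:  f(std::integral_constant<int, 3>{}); break;
        case 4:  f(std::integral_constant<int, 4>{}); break;
        default: f(std::integral_constant<int, 1>{}); break;   // fall back to one row
    }
}

// Usage (assuming the declarations from the binding file are visible):
//   nrows_switch(nrows, [&](auto kNRows) {
//       selective_scan_bwd_cuda<kNRows.value, input_t, float>(params, stream);
//   });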
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_fwd_kernel_nrow.cuh"

template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);

rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_fwd_kernel_nrow.cuh"

template void selective_scan_fwd_cuda<2, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<2, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<2, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);

rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_fwd_kernel_nrow.cuh"

template void selective_scan_fwd_cuda<3, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<3, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<3, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);

rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu
ADDED
@@ -0,0 +1,9 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_fwd_kernel_nrow.cuh"

template void selective_scan_fwd_cuda<4, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<4, at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<4, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);

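Editor's note: every launch configuration used by these kernels processes at most kNThreads * kNItems time steps per chunk, and the largest configuration is 128 threads x 16 items = 2048 elements; that is where the n_chunks = ceil(seqlen / 2048) in the binding code later in this diff comes from. A small worked example of that arithmetic, under the same assumption:

// Worked example of the chunking arithmetic assumed by the launchers in this diff
// (largest configuration: 128 threads x 16 items = 2048 elements per chunk).
#include <cstdio>

int main() {
    const int kNThreads = 128, kNItems = 16;
    const int kChunkSize = kNThreads * kNItems;                  // 2048
    const int seqlen = 5000;                                     // example length
    const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize; // ceil(5000 / 2048) = 3
    printf("chunk size %d, seqlen %d -> %d chunks\n", kChunkSize, seqlen, n_chunks);
    return 0;
}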
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_fwd_kernel_nrow.cuh
ADDED
@@ -0,0 +1,238 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>

#include "selective_scan.h"
#include "selective_scan_common.h"
#include "static_switch.h"

template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_, typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits {
    static_assert(kNItems_ % 4 == 0);
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
    static constexpr int kNItems = kNItems_;
    static constexpr int kNRows = kNRows_;
    static constexpr int MaxDState = MAX_DSTATE / kNRows;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
    static_assert(kNItems % kNElts == 0);
    static constexpr int kNLoads = kNItems / kNElts;
    static constexpr bool kIsEvenLen = kIsEvenLen_;

    static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;

    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using scan_t = float2;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
                                                 sizeof(typename BlockLoadVecT::TempStorage),
                                                 2 * sizeof(typename BlockLoadWeightT::TempStorage),
                                                 2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                 sizeof(typename BlockStoreT::TempStorage),
                                                 sizeof(typename BlockStoreVecT::TempStorage)});
    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
};

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_fwd_kernel(SSMParamsBase params) {
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNItems = Ktraits::kNItems;
    constexpr int kNRows = Ktraits::kNRows;
    constexpr bool kDirectIO = Ktraits::kDirectIO;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int dim_id_nrow = dim_id * kNRows;
    const int group_id = dim_id_nrow / (params.dim_ngroups_ratio);
    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
        + dim_id_nrow * params.u_d_stride;
    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
        + dim_id_nrow * params.delta_d_stride;
    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id_nrow * params.A_d_stride;
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
    scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id_nrow) * params.n_chunks * params.dstate;

    float D_val[kNRows] = {0};
    if (params.D_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id_nrow + r];
        }
    }
    float delta_bias[kNRows] = {0};
    if (params.delta_bias_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id_nrow + r];
        }
    }

    constexpr int kChunkSize = kNThreads * kNItems;
    for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
            if constexpr (!kDirectIO) { __syncthreads(); }
            load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
        }
        u += kChunkSize;
        delta += kChunkSize;

        float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float u_val = float(u_vals[r][i]);
                delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
                if (params.delta_softplus) {
                    delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
                }
                delta_u_vals[r][i] = delta_vals[r][i] * u_val;
                out_vals[r][i] = D_val[r] * u_val;
            }
        }

        __syncthreads();
        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
            weight_t A_val[kNRows];
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
                // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
                constexpr float kLog2e = M_LOG2E;
                A_val[r] *= kLog2e;
            }
            weight_t B_vals[kNItems], C_vals[kNItems];
            load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
                smem_load_weight, (params.seqlen - chunk * kChunkSize));
            auto &smem_load_weight_C = smem_load_weight1;
            load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
                smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                if (r > 0) { __syncthreads(); }  // Scan could be using the same smem
                scan_t thread_data[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                                                 B_vals[i] * delta_u_vals[r][i]);
                    if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct
                        if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                            thread_data[i] = make_float2(1.f, 0.f);
                        }
                    }
                }
                // Initialize running total
                scan_t running_prefix;
                // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
                running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f);
                // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
                );
                // There's a syncthreads in the scan op, so we don't need to sync here.
                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
                if (threadIdx.x == 0) {
                    smem_running_prefix[state_idx + r * Ktraits::MaxDState] = prefix_op.running_prefix;
                    x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
                }
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    out_vals[r][i] += thread_data[i].y * C_vals[i];
                }
            }
        }

        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
            + dim_id_nrow * params.out_d_stride + chunk * kChunkSize;
        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
        }

        Bvar += kChunkSize;
        Cvar += kChunkSize;
    }
}

template<int kNThreads, int kNItems, int kNRows, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t);
        // printf("smem_size = %d\n", kSmemSize);
        dim3 grid(params.batch, params.dim / kNRows);
        auto kernel = &selective_scan_fwd_kernel<Ktraits>;
        if (kSmemSize >= 48 * 1024) {
            C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<int knrows, typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
    if (params.seqlen <= 128) {
        selective_scan_fwd_launch<32, 4, knrows, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 256) {
        selective_scan_fwd_launch<32, 8, knrows, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_fwd_launch<32, 16, knrows, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 1024) {
        selective_scan_fwd_launch<64, 16, knrows, input_t, weight_t>(params, stream);
    } else {
        selective_scan_fwd_launch<128, 16, knrows, input_t, weight_t>(params, stream);
    }
}
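Editor's note: both kernels above multiply A by M_LOG2E and then call exp2f instead of expf; this is just the identity e^x = 2^(x * log2 e), and exp2 maps more directly onto the GPU special-function unit. A standalone numerical check of the rewrite (a sketch, not repository code):

// Sketch verifying the exp2f(x * M_LOG2E) == expf(x) rewrite used by the kernels.
#define _USE_MATH_DEFINES
#include <cmath>
#include <cstdio>

int main() {
    const float x = -1.25f;                           // e.g. a typical delta * A value
    const float via_exp  = expf(x);
    const float via_exp2 = exp2f(x * (float)M_LOG2E); // e^x = 2^(x * log2(e))
    printf("expf: %.7f  exp2f: %.7f  diff: %.3g\n", via_exp, via_exp2, (double)(via_exp - via_exp2));
    return 0;
}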
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_nrow.cpp
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 6 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 7 |
+
#include <torch/extension.h>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#include "selective_scan.h"
|
| 11 |
+
#define MAX_DSTATE 256
|
| 12 |
+
|
| 13 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 14 |
+
using weight_t = float;
|
| 15 |
+
|
| 16 |
+
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
| 17 |
+
if (ITYPE == at::ScalarType::Half) { \
|
| 18 |
+
using input_t = at::Half; \
|
| 19 |
+
__VA_ARGS__(); \
|
| 20 |
+
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
| 21 |
+
using input_t = at::BFloat16; \
|
| 22 |
+
__VA_ARGS__(); \
|
| 23 |
+
} else if (ITYPE == at::ScalarType::Float) { \
|
| 24 |
+
using input_t = float; \
|
| 25 |
+
__VA_ARGS__(); \
|
| 26 |
+
} else { \
|
| 27 |
+
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
#define INT_SWITCH(INT, NAME, ...) [&] { \
|
| 31 |
+
if (INT == 2) {constexpr int NAME = 2; __VA_ARGS__(); } \
|
| 32 |
+
else if (INT == 3) {constexpr int NAME = 3; __VA_ARGS__(); } \
|
| 33 |
+
else if (INT == 4) {constexpr int NAME = 4; __VA_ARGS__(); } \
|
| 34 |
+
else {constexpr int NAME = 1; __VA_ARGS__(); } \
|
| 35 |
+
}() \
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
template<int knrows, typename input_t, typename weight_t>
|
| 39 |
+
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
| 40 |
+
|
| 41 |
+
template <int knrows, typename input_t, typename weight_t>
|
| 42 |
+
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream);
|
| 43 |
+
|
| 44 |
+
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
| 45 |
+
// sizes
|
| 46 |
+
const size_t batch,
|
| 47 |
+
const size_t dim,
|
| 48 |
+
const size_t seqlen,
|
| 49 |
+
const size_t dstate,
|
| 50 |
+
const size_t n_groups,
|
| 51 |
+
const size_t n_chunks,
|
| 52 |
+
// device pointers
|
| 53 |
+
const at::Tensor u,
|
| 54 |
+
const at::Tensor delta,
|
| 55 |
+
const at::Tensor A,
|
| 56 |
+
const at::Tensor B,
|
| 57 |
+
const at::Tensor C,
|
| 58 |
+
const at::Tensor out,
|
| 59 |
+
void* D_ptr,
|
| 60 |
+
void* delta_bias_ptr,
|
| 61 |
+
void* x_ptr,
|
| 62 |
+
bool delta_softplus) {
|
| 63 |
+
|
| 64 |
+
// Reset the parameters
|
| 65 |
+
memset(¶ms, 0, sizeof(params));
|
| 66 |
+
|
| 67 |
+
params.batch = batch;
|
| 68 |
+
params.dim = dim;
|
| 69 |
+
params.seqlen = seqlen;
|
| 70 |
+
params.dstate = dstate;
|
| 71 |
+
params.n_groups = n_groups;
|
| 72 |
+
params.n_chunks = n_chunks;
|
| 73 |
+
params.dim_ngroups_ratio = dim / n_groups;
|
| 74 |
+
|
| 75 |
+
params.delta_softplus = delta_softplus;
|
| 76 |
+
|
| 77 |
+
// Set the pointers and strides.
|
| 78 |
+
params.u_ptr = u.data_ptr();
|
| 79 |
+
params.delta_ptr = delta.data_ptr();
|
| 80 |
+
params.A_ptr = A.data_ptr();
|
| 81 |
+
params.B_ptr = B.data_ptr();
|
| 82 |
+
params.C_ptr = C.data_ptr();
|
| 83 |
+
params.D_ptr = D_ptr;
|
| 84 |
+
params.delta_bias_ptr = delta_bias_ptr;
|
| 85 |
+
params.out_ptr = out.data_ptr();
|
| 86 |
+
params.x_ptr = x_ptr;
|
| 87 |
+
|
| 88 |
+
// All stride are in elements, not bytes.
|
| 89 |
+
params.A_d_stride = A.stride(0);
|
| 90 |
+
params.A_dstate_stride = A.stride(1);
|
| 91 |
+
params.B_batch_stride = B.stride(0);
|
| 92 |
+
params.B_group_stride = B.stride(1);
|
| 93 |
+
params.B_dstate_stride = B.stride(2);
|
| 94 |
+
params.C_batch_stride = C.stride(0);
|
| 95 |
+
params.C_group_stride = C.stride(1);
|
| 96 |
+
params.C_dstate_stride = C.stride(2);
|
| 97 |
+
params.u_batch_stride = u.stride(0);
|
| 98 |
+
params.u_d_stride = u.stride(1);
|
| 99 |
+
params.delta_batch_stride = delta.stride(0);
|
| 100 |
+
params.delta_d_stride = delta.stride(1);
|
| 101 |
+
|
| 102 |
+
params.out_batch_stride = out.stride(0);
|
| 103 |
+
params.out_d_stride = out.stride(1);
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
void set_ssm_params_bwd(SSMParamsBwd ¶ms,
|
| 107 |
+
// sizes
|
| 108 |
+
const size_t batch,
|
| 109 |
+
const size_t dim,
|
| 110 |
+
const size_t seqlen,
|
| 111 |
+
const size_t dstate,
|
| 112 |
+
const size_t n_groups,
|
| 113 |
+
const size_t n_chunks,
|
| 114 |
+
// device pointers
|
| 115 |
+
const at::Tensor u,
|
| 116 |
+
const at::Tensor delta,
|
| 117 |
+
const at::Tensor A,
|
| 118 |
+
const at::Tensor B,
|
| 119 |
+
const at::Tensor C,
|
| 120 |
+
const at::Tensor out,
|
| 121 |
+
void* D_ptr,
|
| 122 |
+
void* delta_bias_ptr,
|
| 123 |
+
void* x_ptr,
|
| 124 |
+
const at::Tensor dout,
|
| 125 |
+
const at::Tensor du,
|
| 126 |
+
const at::Tensor ddelta,
|
| 127 |
+
const at::Tensor dA,
|
| 128 |
+
const at::Tensor dB,
|
| 129 |
+
const at::Tensor dC,
|
| 130 |
+
void* dD_ptr,
|
| 131 |
+
void* ddelta_bias_ptr,
|
| 132 |
+
bool delta_softplus) {
|
| 133 |
+
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
|
| 134 |
+
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks,
|
| 135 |
+
u, delta, A, B, C, dout,
|
| 136 |
+
D_ptr, delta_bias_ptr, x_ptr, delta_softplus);
|
| 137 |
+
|
| 138 |
+
// Set the pointers and strides.
|
| 139 |
+
params.dout_ptr = dout.data_ptr();
|
| 140 |
+
params.du_ptr = du.data_ptr();
|
| 141 |
+
params.dA_ptr = dA.data_ptr();
|
| 142 |
+
params.dB_ptr = dB.data_ptr();
|
| 143 |
+
params.dC_ptr = dC.data_ptr();
|
| 144 |
+
params.dD_ptr = dD_ptr;
|
| 145 |
+
params.ddelta_ptr = ddelta.data_ptr();
|
| 146 |
+
params.ddelta_bias_ptr = ddelta_bias_ptr;
|
| 147 |
+
// All stride are in elements, not bytes.
|
| 148 |
+
params.dout_batch_stride = dout.stride(0);
|
| 149 |
+
params.dout_d_stride = dout.stride(1);
|
| 150 |
+
params.dA_d_stride = dA.stride(0);
|
| 151 |
+
params.dA_dstate_stride = dA.stride(1);
|
| 152 |
+
params.dB_batch_stride = dB.stride(0);
|
| 153 |
+
params.dB_group_stride = dB.stride(1);
|
| 154 |
+
params.dB_dstate_stride = dB.stride(2);
|
| 155 |
+
params.dC_batch_stride = dC.stride(0);
|
| 156 |
+
params.dC_group_stride = dC.stride(1);
|
| 157 |
+
params.dC_dstate_stride = dC.stride(2);
|
| 158 |
+
params.du_batch_stride = du.stride(0);
|
| 159 |
+
params.du_d_stride = du.stride(1);
|
| 160 |
+
params.ddelta_batch_stride = ddelta.stride(0);
|
| 161 |
+
params.ddelta_d_stride = ddelta.stride(1);
|
| 162 |
+
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
std::vector<at::Tensor>
|
| 166 |
+
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
|
| 167 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
| 168 |
+
const c10::optional<at::Tensor> &D_,
|
| 169 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
| 170 |
+
bool delta_softplus,
|
| 171 |
+
int nrows
|
| 172 |
+
) {
|
| 173 |
+
auto input_type = u.scalar_type();
|
| 174 |
+
auto weight_type = A.scalar_type();
|
| 175 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 176 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
| 177 |
+
|
| 178 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
| 179 |
+
TORCH_CHECK(B.scalar_type() == input_type);
|
| 180 |
+
TORCH_CHECK(C.scalar_type() == input_type);
|
| 181 |
+
|
| 182 |
+
TORCH_CHECK(u.is_cuda());
|
| 183 |
+
TORCH_CHECK(delta.is_cuda());
|
| 184 |
+
TORCH_CHECK(A.is_cuda());
|
| 185 |
+
TORCH_CHECK(B.is_cuda());
|
| 186 |
+
TORCH_CHECK(C.is_cuda());
|
| 187 |
+
|
| 188 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
| 189 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
| 190 |
+
|
| 191 |
+
const auto sizes = u.sizes();
|
| 192 |
+
const int batch_size = sizes[0];
|
| 193 |
+
const int dim = sizes[1];
|
| 194 |
+
const int seqlen = sizes[2];
|
| 195 |
+
const int dstate = A.size(1);
|
| 196 |
+
const int n_groups = B.size(1);
|
| 197 |
+
|
| 198 |
+
TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows");
|
| 199 |
+
TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows");
|
| 200 |
+
|
| 201 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
| 202 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
| 203 |
+
CHECK_SHAPE(A, dim, dstate);
|
| 204 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
| 205 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
| 206 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
| 207 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
| 208 |
+
|
| 209 |
+
if (D_.has_value()) {
|
| 210 |
+
auto D = D_.value();
|
| 211 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
| 212 |
+
TORCH_CHECK(D.is_cuda());
|
| 213 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
| 214 |
+
CHECK_SHAPE(D, dim);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
if (delta_bias_.has_value()) {
|
| 218 |
+
auto delta_bias = delta_bias_.value();
|
| 219 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
| 220 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
| 221 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
| 222 |
+
CHECK_SHAPE(delta_bias, dim);
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel
|
| 226 |
+
at::Tensor out = torch::empty_like(delta);
|
| 227 |
+
at::Tensor x;
|
| 228 |
+
x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
|
| 229 |
+
|
| 230 |
+
SSMParamsBase params;
|
| 231 |
+
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
|
| 232 |
+
u, delta, A, B, C, out,
|
| 233 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
| 234 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
| 235 |
+
x.data_ptr(),
|
| 236 |
+
delta_softplus);
|
| 237 |
+
|
| 238 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 239 |
+
// Cast to char to avoid compiler warning about narrowing
|
| 240 |
+
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
| 241 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 242 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
| 243 |
+
INT_SWITCH(nrows, kNRows, [&] {
|
| 244 |
+
selective_scan_fwd_cuda<kNRows, input_t, weight_t>(params, stream);
|
| 245 |
+
});
|
| 246 |
+
});
|
| 247 |
+
std::vector<at::Tensor> result = {out, x};
|
| 248 |
+
return result;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
std::vector<at::Tensor>
|
| 252 |
+
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
|
| 253 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
| 254 |
+
const c10::optional<at::Tensor> &D_,
|
| 255 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
| 256 |
+
const at::Tensor &dout,
|
| 257 |
+
const c10::optional<at::Tensor> &x_,
|
| 258 |
+
bool delta_softplus,
|
| 259 |
+
int nrows
|
| 260 |
+
) {
|
| 261 |
+
auto input_type = u.scalar_type();
|
| 262 |
+
auto weight_type = A.scalar_type();
|
| 263 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 264 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
| 265 |
+
|
| 266 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
| 267 |
+
TORCH_CHECK(B.scalar_type() == input_type);
|
| 268 |
+
TORCH_CHECK(C.scalar_type() == input_type);
|
| 269 |
+
TORCH_CHECK(dout.scalar_type() == input_type);
|
| 270 |
+
|
| 271 |
+
TORCH_CHECK(u.is_cuda());
|
| 272 |
+
TORCH_CHECK(delta.is_cuda());
|
| 273 |
+
TORCH_CHECK(A.is_cuda());
|
| 274 |
+
TORCH_CHECK(B.is_cuda());
|
| 275 |
+
TORCH_CHECK(C.is_cuda());
|
| 276 |
+
TORCH_CHECK(dout.is_cuda());
|
| 277 |
+
|
| 278 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
| 279 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
| 280 |
+
TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
|
| 281 |
+
|
| 282 |
+
const auto sizes = u.sizes();
|
| 283 |
+
const int batch_size = sizes[0];
|
| 284 |
+
const int dim = sizes[1];
|
| 285 |
+
const int seqlen = sizes[2];
|
| 286 |
+
const int dstate = A.size(1);
|
| 287 |
+
const int n_groups = B.size(1);
|
| 288 |
+
|
| 289 |
+
TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be divisible by n_groups * nrows");
|
| 290 |
+
TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows");
|
| 291 |
+
|
| 292 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
| 293 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
| 294 |
+
CHECK_SHAPE(A, dim, dstate);
|
| 295 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
| 296 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
| 297 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
| 298 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
| 299 |
+
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
| 300 |
+
|
| 301 |
+
if (D_.has_value()) {
|
| 302 |
+
auto D = D_.value();
|
| 303 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
| 304 |
+
TORCH_CHECK(D.is_cuda());
|
| 305 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
| 306 |
+
CHECK_SHAPE(D, dim);
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
if (delta_bias_.has_value()) {
|
| 310 |
+
auto delta_bias = delta_bias_.value();
|
| 311 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
| 312 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
| 313 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
| 314 |
+
CHECK_SHAPE(delta_bias, dim);
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
at::Tensor out;
|
| 318 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
| 319 |
+
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
| 320 |
+
if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
|
| 321 |
+
if (x_.has_value()) {
|
| 322 |
+
auto x = x_.value();
|
| 323 |
+
TORCH_CHECK(x.scalar_type() == weight_type);
|
| 324 |
+
TORCH_CHECK(x.is_cuda());
|
| 325 |
+
TORCH_CHECK(x.is_contiguous());
|
| 326 |
+
CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
at::Tensor du = torch::empty_like(u);
|
| 330 |
+
at::Tensor ddelta = torch::empty_like(delta);
|
| 331 |
+
at::Tensor dA = torch::zeros_like(A);
|
| 332 |
+
at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32));
|
| 333 |
+
at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32));
|
| 334 |
+
at::Tensor dD;
|
| 335 |
+
if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
|
| 336 |
+
at::Tensor ddelta_bias;
|
| 337 |
+
if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
|
| 338 |
+
|
| 339 |
+
SSMParamsBwd params;
|
| 340 |
+
set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
|
| 341 |
+
u, delta, A, B, C, out,
|
| 342 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
| 343 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
| 344 |
+
x_.has_value() ? x_.value().data_ptr() : nullptr,
|
| 345 |
+
dout, du, ddelta, dA, dB, dC,
|
| 346 |
+
D_.has_value() ? dD.data_ptr() : nullptr,
|
| 347 |
+
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
|
| 348 |
+
delta_softplus);
|
| 349 |
+
|
| 350 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 351 |
+
// Cast to char to avoid compiler warning about narrowing
|
| 352 |
+
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
| 353 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 354 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
|
| 355 |
+
// constexpr int kNRows = 1;
|
| 356 |
+
INT_SWITCH(nrows, kNRows, [&] {
|
| 357 |
+
selective_scan_bwd_cuda<kNRows, input_t, weight_t>(params, stream);
|
| 358 |
+
});
|
| 359 |
+
});
|
| 360 |
+
std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
|
| 361 |
+
return result;
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 365 |
+
m.def("fwd", &selective_scan_fwd, "Selective scan forward");
|
| 366 |
+
m.def("bwd", &selective_scan_bwd, "Selective scan backward");
|
| 367 |
+
}
|
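For orientation (not part of this diff): a minimal Python sketch of how a pybind11 extension exposing `fwd`/`bwd` like the module above is typically wrapped in a `torch.autograd.Function`. The module name `selective_scan_cuda_nrow` and the exact Python-side argument order are assumptions for illustration; the real bindings in this repository may differ.

import torch
import selective_scan_cuda_nrow  # assumed module name; illustration only

class SelectiveScanNRowFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None,
                delta_softplus=False, nrows=1):
        # fwd returns (out, x); x holds the per-chunk scan states reused by bwd
        out, x = selective_scan_cuda_nrow.fwd(
            u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)
        ctx.delta_softplus = delta_softplus
        ctx.nrows = nrows
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    def backward(ctx, dout):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        du, ddelta, dA, dB, dC, dD, ddelta_bias = selective_scan_cuda_nrow.bwd(
            u, delta, A, B, C, D, delta_bias, dout.contiguous(), x,
            ctx.delta_softplus, ctx.nrows)
        # one gradient per forward input; the non-tensor arguments get None
        return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None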
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_bwd_kernel_oflex.cuh
ADDED
|
@@ -0,0 +1,323 @@
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <c10/util/BFloat16.h>
|
| 8 |
+
#include <c10/util/Half.h>
|
| 9 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 10 |
+
#include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
|
| 11 |
+
|
| 12 |
+
#include <cub/block/block_load.cuh>
|
| 13 |
+
#include <cub/block/block_store.cuh>
|
| 14 |
+
#include <cub/block/block_scan.cuh>
|
| 15 |
+
#include <cub/block/block_reduce.cuh>
|
| 16 |
+
|
| 17 |
+
#include "selective_scan.h"
|
| 18 |
+
#include "selective_scan_common.h"
|
| 19 |
+
#include "reverse_scan.cuh"
|
| 20 |
+
#include "static_switch.h"
|
| 21 |
+
|
| 22 |
+
template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kDeltaSoftplus_, typename input_t_, typename weight_t_, typename output_t_>
|
| 23 |
+
struct Selective_Scan_bwd_kernel_traits {
|
| 24 |
+
static_assert(kNItems_ % 4 == 0);
|
| 25 |
+
using input_t = input_t_;
|
| 26 |
+
using weight_t = weight_t_;
|
| 27 |
+
using output_t = output_t_;
|
| 28 |
+
|
| 29 |
+
static constexpr int kNThreads = kNThreads_;
|
| 30 |
+
static constexpr int kNItems = kNItems_;
|
| 31 |
+
static constexpr int MaxDState = MAX_DSTATE;
|
| 32 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 33 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 34 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
|
| 35 |
+
static_assert(kNItems % kNElts == 0);
|
| 36 |
+
static constexpr int kNLoads = kNItems / kNElts;
|
| 37 |
+
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
| 38 |
+
static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
|
| 39 |
+
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
|
| 40 |
+
// For complex this would lead to massive register spilling, so we keep it at 2.
|
| 41 |
+
static constexpr int kMinBlocks = kNThreads == 128 ? 3 : 2;
|
| 42 |
+
static constexpr int kNLoadsOutput = sizeof(output_t) * kNLoads / kNBytes;
|
| 43 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 44 |
+
using scan_t = float2;
|
| 45 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 46 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 47 |
+
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 48 |
+
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 49 |
+
using BlockLoadOutputT = cub::BlockLoad<output_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 50 |
+
using BlockLoadOutputVecT = cub::BlockLoad<vec_t, kNThreads, kNLoadsOutput, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 51 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 52 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 53 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
| 54 |
+
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
| 55 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
| 56 |
+
using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
|
| 57 |
+
using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
|
| 58 |
+
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
|
| 59 |
+
using BlockExchangeT = cub::BlockExchange<float, kNThreads, kNItems>;
|
| 60 |
+
static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
| 61 |
+
sizeof(typename BlockLoadVecT::TempStorage),
|
| 62 |
+
2 * sizeof(typename BlockLoadWeightT::TempStorage),
|
| 63 |
+
2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
| 64 |
+
sizeof(typename BlockLoadOutputT::TempStorage),
|
| 65 |
+
sizeof(typename BlockLoadOutputVecT::TempStorage),
|
| 66 |
+
sizeof(typename BlockStoreT::TempStorage),
|
| 67 |
+
sizeof(typename BlockStoreVecT::TempStorage)});
|
| 68 |
+
static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage);
|
| 69 |
+
static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
|
| 70 |
+
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
template<typename Ktraits>
|
| 74 |
+
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
| 75 |
+
void selective_scan_bwd_kernel(SSMParamsBwd params) {
|
| 76 |
+
constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
|
| 77 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 78 |
+
constexpr int kNItems = Ktraits::kNItems;
|
| 79 |
+
using input_t = typename Ktraits::input_t;
|
| 80 |
+
using weight_t = typename Ktraits::weight_t;
|
| 81 |
+
using output_t = typename Ktraits::output_t;
|
| 82 |
+
using scan_t = typename Ktraits::scan_t;
|
| 83 |
+
|
| 84 |
+
// Shared memory.
|
| 85 |
+
extern __shared__ char smem_[];
|
| 86 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
| 87 |
+
auto& smem_load1 = reinterpret_cast<typename Ktraits::BlockLoadOutputT::TempStorage&>(smem_);
|
| 88 |
+
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
| 89 |
+
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
| 90 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
| 91 |
+
auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
| 92 |
+
auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
|
| 93 |
+
auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
|
| 94 |
+
auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
|
| 95 |
+
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
|
| 96 |
+
auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
|
| 97 |
+
weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
|
| 98 |
+
scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads);
|
| 99 |
+
weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + Ktraits::MaxDState);
|
| 100 |
+
|
| 101 |
+
const int batch_id = blockIdx.x;
|
| 102 |
+
const int dim_id = blockIdx.y;
|
| 103 |
+
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
| 104 |
+
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
| 105 |
+
+ dim_id * params.u_d_stride;
|
| 106 |
+
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
| 107 |
+
+ dim_id * params.delta_d_stride;
|
| 108 |
+
|
| 109 |
+
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
|
| 110 |
+
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
| 111 |
+
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
| 112 |
+
weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
|
| 113 |
+
weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
|
| 114 |
+
+ (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride);
|
| 115 |
+
weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
|
| 116 |
+
+ (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride);
|
| 117 |
+
float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
|
| 118 |
+
float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
|
| 119 |
+
float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
|
| 120 |
+
float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
|
| 121 |
+
scan_t *x = params.x_ptr == nullptr
|
| 122 |
+
? nullptr
|
| 123 |
+
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
|
| 124 |
+
float dD_val = 0;
|
| 125 |
+
float ddelta_bias_val = 0;
|
| 126 |
+
|
| 127 |
+
output_t *dout = reinterpret_cast<output_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride + dim_id * params.dout_d_stride;
|
| 128 |
+
|
| 129 |
+
constexpr int kChunkSize = kNThreads * kNItems;
|
| 130 |
+
u += (params.n_chunks - 1) * kChunkSize;
|
| 131 |
+
delta += (params.n_chunks - 1) * kChunkSize;
|
| 132 |
+
dout += (params.n_chunks - 1) * kChunkSize;
|
| 133 |
+
Bvar += (params.n_chunks - 1) * kChunkSize;
|
| 134 |
+
Cvar += (params.n_chunks - 1) * kChunkSize;
|
| 135 |
+
for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
|
| 136 |
+
input_t u_vals[kNItems];
|
| 137 |
+
input_t delta_vals_load[kNItems];
|
| 138 |
+
float dout_vals[kNItems];
|
| 139 |
+
__syncthreads();
|
| 140 |
+
load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
| 141 |
+
__syncthreads();
|
| 142 |
+
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
| 143 |
+
__syncthreads();
|
| 144 |
+
if constexpr (std::is_same_v<output_t, input_t>) {
|
| 145 |
+
input_t dout_vals_load[kNItems];
|
| 146 |
+
load_input<Ktraits>(reinterpret_cast<input_t *>(dout), dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
| 147 |
+
Converter<typename Ktraits::input_t, kNItems>::to_float(dout_vals_load, dout_vals);
|
| 148 |
+
} else {
|
| 149 |
+
static_assert(std::is_same_v<output_t, float>);
|
| 150 |
+
load_output<Ktraits>(dout, dout_vals, smem_load1, params.seqlen - chunk * kChunkSize);
|
| 151 |
+
}
|
| 152 |
+
u -= kChunkSize;
|
| 153 |
+
// Will reload delta at the same location if kDeltaSoftplus
|
| 154 |
+
if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
|
| 155 |
+
dout -= kChunkSize;
|
| 156 |
+
|
| 157 |
+
float delta_vals[kNItems];
|
| 158 |
+
float du_vals[kNItems];
|
| 159 |
+
#pragma unroll
|
| 160 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 161 |
+
delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
|
| 162 |
+
if constexpr (kDeltaSoftplus) {
|
| 163 |
+
delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
|
| 164 |
+
}
|
| 165 |
+
}
|
| 166 |
+
#pragma unroll
|
| 167 |
+
for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
|
| 168 |
+
#pragma unroll
|
| 169 |
+
for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
|
| 170 |
+
|
| 171 |
+
float ddelta_vals[kNItems] = {0};
|
| 172 |
+
__syncthreads();
|
| 173 |
+
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
| 174 |
+
constexpr float kLog2e = M_LOG2E;
|
| 175 |
+
weight_t A_val = A[state_idx * params.A_dstate_stride];
|
| 176 |
+
weight_t A_scaled = A_val * kLog2e;
|
| 177 |
+
weight_t B_vals[kNItems], C_vals[kNItems];
|
| 178 |
+
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
| 179 |
+
smem_load_weight, (params.seqlen - chunk * kChunkSize));
|
| 180 |
+
auto &smem_load_weight_C = smem_load_weight1;
|
| 181 |
+
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
| 182 |
+
smem_load_weight_C, (params.seqlen - chunk * kChunkSize));
|
| 183 |
+
scan_t thread_data[kNItems], thread_reverse_data[kNItems];
|
| 184 |
+
#pragma unroll
|
| 185 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 186 |
+
const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
|
| 187 |
+
thread_data[i] = make_float2(delta_a_exp, delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
| 188 |
+
if (i == 0) {
|
| 189 |
+
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState: threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp;
|
| 190 |
+
} else {
|
| 191 |
+
thread_reverse_data[i - 1].x = delta_a_exp;
|
| 192 |
+
}
|
| 193 |
+
thread_reverse_data[i].y = dout_vals[i] * C_vals[i];
|
| 194 |
+
}
|
| 195 |
+
__syncthreads();
|
| 196 |
+
thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
|
| 197 |
+
? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState])
|
| 198 |
+
: smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState];
|
| 199 |
+
// Initialize running total
|
| 200 |
+
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
|
| 201 |
+
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
| 202 |
+
Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
| 203 |
+
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
| 204 |
+
);
|
| 205 |
+
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
|
| 206 |
+
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
|
| 207 |
+
Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
|
| 208 |
+
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
|
| 209 |
+
);
|
| 210 |
+
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
|
| 211 |
+
weight_t dA_val = 0;
|
| 212 |
+
weight_t dB_vals[kNItems], dC_vals[kNItems];
|
| 213 |
+
#pragma unroll
|
| 214 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 215 |
+
const float dx = thread_reverse_data[i].y;
|
| 216 |
+
const float ddelta_u = dx * B_vals[i];
|
| 217 |
+
du_vals[i] += ddelta_u * delta_vals[i];
|
| 218 |
+
const float a = thread_data[i].y - (delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
| 219 |
+
ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
|
| 220 |
+
dA_val += dx * delta_vals[i] * a;
|
| 221 |
+
dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]);
|
| 222 |
+
dC_vals[i] = dout_vals[i] * thread_data[i].y;
|
| 223 |
+
}
|
| 224 |
+
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
|
| 225 |
+
Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
|
| 226 |
+
auto &smem_exchange_C = smem_exchange1;
|
| 227 |
+
Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
|
| 228 |
+
const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
|
| 229 |
+
weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
| 230 |
+
weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
| 231 |
+
#pragma unroll
|
| 232 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 233 |
+
if (i * kNThreads < seqlen_remaining) {
|
| 234 |
+
{ gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
|
| 235 |
+
{ gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
|
| 239 |
+
if (threadIdx.x == 0) {
|
| 240 |
+
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
if constexpr (kDeltaSoftplus) {
|
| 245 |
+
input_t delta_vals_load[kNItems];
|
| 246 |
+
__syncthreads();
|
| 247 |
+
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
| 248 |
+
delta -= kChunkSize;
|
| 249 |
+
#pragma unroll
|
| 250 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 251 |
+
float delta_val = float(delta_vals_load[i]) + delta_bias;
|
| 252 |
+
float delta_val_neg_exp = expf(-delta_val);
|
| 253 |
+
ddelta_vals[i] = delta_val <= 20.f
|
| 254 |
+
? ddelta_vals[i] / (1.f + delta_val_neg_exp)
|
| 255 |
+
: ddelta_vals[i];
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
__syncthreads();
|
| 260 |
+
#pragma unroll
|
| 261 |
+
for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
|
| 262 |
+
|
| 263 |
+
input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
|
| 264 |
+
+ dim_id * params.du_d_stride + chunk * kChunkSize;
|
| 265 |
+
input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
|
| 266 |
+
+ dim_id * params.ddelta_d_stride + chunk * kChunkSize;
|
| 267 |
+
__syncthreads();
|
| 268 |
+
store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
| 269 |
+
__syncthreads();
|
| 270 |
+
store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
| 271 |
+
Bvar -= kChunkSize;
|
| 272 |
+
Cvar -= kChunkSize;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
if (params.dD_ptr != nullptr) {
|
| 276 |
+
__syncthreads();
|
| 277 |
+
dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
|
| 278 |
+
if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
|
| 279 |
+
}
|
| 280 |
+
if (params.ddelta_bias_ptr != nullptr) {
|
| 281 |
+
__syncthreads();
|
| 282 |
+
ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
|
| 283 |
+
if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
|
| 284 |
+
}
|
| 285 |
+
__syncthreads();
|
| 286 |
+
for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
| 287 |
+
gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
|
| 288 |
+
}
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename output_t>
|
| 292 |
+
void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
|
| 293 |
+
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
| 294 |
+
BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
|
| 295 |
+
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kDeltaSoftplus, input_t, weight_t, output_t>;
|
| 296 |
+
constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t);
|
| 297 |
+
// printf("smem_size = %d\n", kSmemSize);
|
| 298 |
+
dim3 grid(params.batch, params.dim);
|
| 299 |
+
auto kernel = &selective_scan_bwd_kernel<Ktraits>;
|
| 300 |
+
if (kSmemSize >= 48 * 1024) {
|
| 301 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 302 |
+
}
|
| 303 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 304 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 305 |
+
});
|
| 306 |
+
});
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
template<int knrows, typename input_t, typename weight_t, typename output_t>
|
| 310 |
+
void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
|
| 311 |
+
if (params.seqlen <= 128) {
|
| 312 |
+
selective_scan_bwd_launch<32, 4, input_t, weight_t, output_t>(params, stream);
|
| 313 |
+
} else if (params.seqlen <= 256) {
|
| 314 |
+
selective_scan_bwd_launch<32, 8, input_t, weight_t, output_t>(params, stream);
|
| 315 |
+
} else if (params.seqlen <= 512) {
|
| 316 |
+
selective_scan_bwd_launch<32, 16, input_t, weight_t, output_t>(params, stream);
|
| 317 |
+
} else if (params.seqlen <= 1024) {
|
| 318 |
+
selective_scan_bwd_launch<64, 16, input_t, weight_t, output_t>(params, stream);
|
| 319 |
+
} else {
|
| 320 |
+
selective_scan_bwd_launch<128, 16, input_t, weight_t, output_t>(params, stream);
|
| 321 |
+
}
|
| 322 |
+
}
|
| 323 |
+
|
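Aside (not part of this diff): the forward kernel saves one (a, b) pair per chunk and state in `x`, and the backward kernel above reloads those pairs as running prefixes/postfixes when it re-scans each chunk. A small plain-Python sketch of the associative operator this relies on, under the assumption that each scan element is the pair (exp(delta*A), delta*B*u) as in the kernels:

# A minimal sketch (illustration only) of the associative operator behind the
# chunked scan: each element is a pair (a, b) standing for the map h -> a*h + b.
def combine(e1, e2):
    # applying e1 first and then e2 gives h -> a2*(a1*h + b1) + b2
    a1, b1 = e1
    a2, b2 = e2
    return (a1 * a2, a2 * b1 + b2)

def inclusive_scan(elems, prefix=(1.0, 0.0)):
    # mirrors a per-chunk BlockScan with a running-prefix callback
    out, acc = [], prefix
    for e in elems:
        acc = combine(acc, e)
        out.append(acc)
    return out, acc  # acc is the chunk's final state, i.e. what fwd stores in x

# Scanning chunk by chunk while carrying each chunk's final state forward
# reproduces the full scan exactly; this is the contract between fwd's x tensor
# and the running_prefix reload in the backward kernel above.
elems = [(0.9, 0.1 * t) for t in range(8)]
full, _ = inclusive_scan(elems)
chunked, state = [], (1.0, 0.0)
for start in range(0, len(elems), 3):
    out, state = inclusive_scan(elems[start:start + 3], state)
    chunked.extend(out)
assert chunked == full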
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu
ADDED
|
@@ -0,0 +1,11 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_bwd_kernel_oflex.cuh"

template void selective_scan_bwd_cuda<1, float, float, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::Half, float, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::BFloat16, float, float>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::Half, float, at::Half>(SSMParamsBwd &params, cudaStream_t stream);
template void selective_scan_bwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBwd &params, cudaStream_t stream);
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu
ADDED
|
@@ -0,0 +1,11 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#include "selective_scan_fwd_kernel_oflex.cuh"

template void selective_scan_fwd_cuda<1, float, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_fwd_kernel_oflex.cuh
ADDED
|
@@ -0,0 +1,211 @@
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include <c10/util/BFloat16.h>
|
| 8 |
+
#include <c10/util/Half.h>
|
| 9 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 10 |
+
|
| 11 |
+
#include <cub/block/block_load.cuh>
|
| 12 |
+
#include <cub/block/block_store.cuh>
|
| 13 |
+
#include <cub/block/block_scan.cuh>
|
| 14 |
+
|
| 15 |
+
#include "selective_scan.h"
|
| 16 |
+
#include "selective_scan_common.h"
|
| 17 |
+
#include "static_switch.h"
|
| 18 |
+
|
| 19 |
+
template<int kNThreads_, int kNItems_, bool kIsEvenLen_, typename input_t_, typename weight_t_, typename output_t_>
|
| 20 |
+
struct Selective_Scan_fwd_kernel_traits {
|
| 21 |
+
static_assert(kNItems_ % 4 == 0);
|
| 22 |
+
using input_t = input_t_;
|
| 23 |
+
using weight_t = weight_t_;
|
| 24 |
+
using output_t = output_t_;
|
| 25 |
+
static constexpr int kNThreads = kNThreads_;
|
| 26 |
+
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
| 27 |
+
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
| 28 |
+
static constexpr int kNItems = kNItems_;
|
| 29 |
+
static constexpr int MaxDState = MAX_DSTATE;
|
| 30 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 31 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 32 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
|
| 33 |
+
static_assert(kNItems % kNElts == 0);
|
| 34 |
+
static constexpr int kNLoads = kNItems / kNElts;
|
| 35 |
+
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
| 36 |
+
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
|
| 37 |
+
static constexpr int kNLoadsOutput = sizeof(output_t) * kNLoads / kNBytes;
|
| 38 |
+
static constexpr bool kDirectIOOutput = kDirectIO && (kNLoadsOutput == 1);
|
| 39 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 40 |
+
using scan_t = float2;
|
| 41 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 42 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
| 43 |
+
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
| 44 |
+
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 45 |
+
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
| 46 |
+
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
| 47 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 48 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
|
| 49 |
+
!kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
| 50 |
+
using BlockStoreOutputT = cub::BlockStore<output_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 51 |
+
using BlockStoreOutputVecT = cub::BlockStore<vec_t, kNThreads, kNLoadsOutput,
|
| 52 |
+
!kDirectIOOutput ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
| 53 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
| 54 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
| 55 |
+
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
| 56 |
+
static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
| 57 |
+
sizeof(typename BlockLoadVecT::TempStorage),
|
| 58 |
+
2 * sizeof(typename BlockLoadWeightT::TempStorage),
|
| 59 |
+
2 * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
| 60 |
+
sizeof(typename BlockStoreT::TempStorage),
|
| 61 |
+
sizeof(typename BlockStoreVecT::TempStorage),
|
| 62 |
+
sizeof(typename BlockStoreOutputT::TempStorage),
|
| 63 |
+
sizeof(typename BlockStoreOutputVecT::TempStorage)});
|
| 64 |
+
static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
template<typename Ktraits>
|
| 68 |
+
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
| 69 |
+
void selective_scan_fwd_kernel(SSMParamsBase params) {
|
| 70 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 71 |
+
constexpr int kNItems = Ktraits::kNItems;
|
| 72 |
+
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
| 73 |
+
using input_t = typename Ktraits::input_t;
|
| 74 |
+
using weight_t = typename Ktraits::weight_t;
|
| 75 |
+
using output_t = typename Ktraits::output_t;
|
| 76 |
+
using scan_t = typename Ktraits::scan_t;
|
| 77 |
+
|
| 78 |
+
// Shared memory.
|
| 79 |
+
extern __shared__ char smem_[];
|
| 80 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
| 81 |
+
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
| 82 |
+
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
| 83 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
| 84 |
+
auto& smem_store1 = reinterpret_cast<typename Ktraits::BlockStoreOutputT::TempStorage&>(smem_);
|
| 85 |
+
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
| 86 |
+
scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
|
| 87 |
+
|
| 88 |
+
const int batch_id = blockIdx.x;
|
| 89 |
+
const int dim_id = blockIdx.y;
|
| 90 |
+
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
| 91 |
+
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
| 92 |
+
+ dim_id * params.u_d_stride;
|
| 93 |
+
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
| 94 |
+
+ dim_id * params.delta_d_stride;
|
| 95 |
+
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
|
| 96 |
+
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
| 97 |
+
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
| 98 |
+
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks * params.dstate;
|
| 99 |
+
|
| 100 |
+
float D_val = 0; // stays 0 when D is not provided (params.D_ptr == nullptr)
|
| 101 |
+
if (params.D_ptr != nullptr) {
|
| 102 |
+
D_val = reinterpret_cast<float *>(params.D_ptr)[dim_id];
|
| 103 |
+
}
|
| 104 |
+
float delta_bias = 0;
|
| 105 |
+
if (params.delta_bias_ptr != nullptr) {
|
| 106 |
+
delta_bias = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
constexpr int kChunkSize = kNThreads * kNItems;
|
| 110 |
+
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
|
| 111 |
+
input_t u_vals[kNItems], delta_vals_load[kNItems];
|
| 112 |
+
__syncthreads();
|
| 113 |
+
load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
| 114 |
+
if constexpr (!kDirectIO) { __syncthreads(); }
|
| 115 |
+
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
| 116 |
+
u += kChunkSize;
|
| 117 |
+
delta += kChunkSize;
|
| 118 |
+
|
| 119 |
+
float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems];
|
| 120 |
+
#pragma unroll
|
| 121 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 122 |
+
float u_val = float(u_vals[i]);
|
| 123 |
+
delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
|
| 124 |
+
if (params.delta_softplus) {
|
| 125 |
+
delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
|
| 126 |
+
}
|
| 127 |
+
delta_u_vals[i] = delta_vals[i] * u_val;
|
| 128 |
+
out_vals[i] = D_val * u_val;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
__syncthreads();
|
| 132 |
+
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
| 133 |
+
constexpr float kLog2e = M_LOG2E;
|
| 134 |
+
weight_t A_val = A[state_idx * params.A_dstate_stride];
|
| 135 |
+
A_val *= kLog2e;
|
| 136 |
+
weight_t B_vals[kNItems], C_vals[kNItems];
|
| 137 |
+
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
| 138 |
+
smem_load_weight, (params.seqlen - chunk * kChunkSize));
|
| 139 |
+
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
| 140 |
+
smem_load_weight1, (params.seqlen - chunk * kChunkSize));
|
| 141 |
+
__syncthreads();
|
| 142 |
+
scan_t thread_data[kNItems];
|
| 143 |
+
#pragma unroll
|
| 144 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 145 |
+
thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]);
|
| 146 |
+
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
| 147 |
+
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
| 148 |
+
thread_data[i] = make_float2(1.f, 0.f);
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
// Initialize running total
|
| 153 |
+
scan_t running_prefix;
|
| 154 |
+
// If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
|
| 155 |
+
running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
| 156 |
+
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
| 157 |
+
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
| 158 |
+
Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
| 159 |
+
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
| 160 |
+
);
|
| 161 |
+
// There's a syncthreads in the scan op, so we don't need to sync here.
|
| 162 |
+
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
|
| 163 |
+
if (threadIdx.x == 0) {
|
| 164 |
+
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
| 165 |
+
x[chunk * params.dstate + state_idx] = prefix_op.running_prefix;
|
| 166 |
+
}
|
| 167 |
+
#pragma unroll
|
| 168 |
+
for (int i = 0; i < kNItems; ++i) {
|
| 169 |
+
out_vals[i] += thread_data[i].y * C_vals[i];
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
output_t *out = reinterpret_cast<output_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
| 174 |
+
+ dim_id * params.out_d_stride + chunk * kChunkSize;
|
| 175 |
+
__syncthreads();
|
| 176 |
+
store_output1<Ktraits>(out, out_vals, smem_store1, params.seqlen - chunk * kChunkSize);
|
| 177 |
+
Bvar += kChunkSize;
|
| 178 |
+
Cvar += kChunkSize;
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename output_t>
|
| 183 |
+
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
|
| 184 |
+
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
| 185 |
+
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, input_t, weight_t, output_t>;
|
| 186 |
+
constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t);
|
| 187 |
+
// printf("smem_size = %d\n", kSmemSize);
|
| 188 |
+
dim3 grid(params.batch, params.dim);
|
| 189 |
+
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
| 190 |
+
if (kSmemSize >= 48 * 1024) {
|
| 191 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 192 |
+
}
|
| 193 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 194 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 195 |
+
});
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
template<int knrows, typename input_t, typename weight_t, typename output_t>
|
| 199 |
+
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
|
| 200 |
+
if (params.seqlen <= 128) {
|
| 201 |
+
selective_scan_fwd_launch<32, 4, input_t, weight_t, output_t>(params, stream);
|
| 202 |
+
} else if (params.seqlen <= 256) {
|
| 203 |
+
selective_scan_fwd_launch<32, 8, input_t, weight_t, output_t>(params, stream);
|
| 204 |
+
} else if (params.seqlen <= 512) {
|
| 205 |
+
selective_scan_fwd_launch<32, 16, input_t, weight_t, output_t>(params, stream);
|
| 206 |
+
} else if (params.seqlen <= 1024) {
|
| 207 |
+
selective_scan_fwd_launch<64, 16, input_t, weight_t, output_t>(params, stream);
|
| 208 |
+
} else {
|
| 209 |
+
selective_scan_fwd_launch<128, 16, input_t, weight_t, output_t>(params, stream);
|
| 210 |
+
}
|
| 211 |
+
}
|
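For reference (not part of this diff): a small pure-PyTorch sketch of the recurrence the forward kernel above evaluates chunk by chunk; `exp2f(delta * A * log2(e))` in the kernel is just `exp(delta * A)`. Shapes follow the checks in the accompanying .cpp files: u and delta are (batch, dim, seqlen), A is (dim, dstate), B and C are (batch, n_groups, dstate, seqlen). The function name `selective_scan_ref` is illustrative; this is a readability aid, not the optimized path.

import torch
import torch.nn.functional as F

def selective_scan_ref(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False):
    # u, delta: (batch, dim, seqlen); A: (dim, dstate); B, C: (batch, n_groups, dstate, seqlen)
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    n_groups = B.shape[1]
    if delta_bias is not None:
        delta = delta + delta_bias[None, :, None]
    if delta_softplus:
        # the kernel guards softplus with a threshold at 20; F.softplus matches closely
        delta = F.softplus(delta)
    # expand B/C from groups to channels, matching group_id = dim_id / dim_ngroups_ratio
    B = B.repeat_interleave(dim // n_groups, dim=1)
    C = C.repeat_interleave(dim // n_groups, dim=1)
    h = u.new_zeros(batch, dim, dstate, dtype=torch.float32)
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, :, t, None].float() * A[None].float())       # (batch, dim, dstate)
        dBu = delta[:, :, t, None].float() * B[:, :, :, t].float() * u[:, :, t, None].float()
        h = dA * h + dBu
        ys.append((h * C[:, :, :, t].float()).sum(-1))                        # (batch, dim)
    y = torch.stack(ys, dim=-1)                                               # (batch, dim, seqlen)
    if D is not None:
        y = y + D[None, :, None] * u.float()
    # the oflex variant can also emit float32 output (out_float); here we match u's dtype
    return y.to(u.dtype)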
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp
ADDED
|
@@ -0,0 +1,363 @@
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 6 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 7 |
+
#include <torch/extension.h>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#include "selective_scan.h"
|
| 11 |
+
#define MAX_DSTATE 256
|
| 12 |
+
|
| 13 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 14 |
+
using weight_t = float;
|
| 15 |
+
|
| 16 |
+
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
| 17 |
+
if (ITYPE == at::ScalarType::Half) { \
|
| 18 |
+
using input_t = at::Half; \
|
| 19 |
+
__VA_ARGS__(); \
|
| 20 |
+
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
| 21 |
+
using input_t = at::BFloat16; \
|
| 22 |
+
__VA_ARGS__(); \
|
| 23 |
+
} else if (ITYPE == at::ScalarType::Float) { \
|
| 24 |
+
using input_t = float; \
|
| 25 |
+
__VA_ARGS__(); \
|
| 26 |
+
} else { \
|
| 27 |
+
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
template<int knrows, typename input_t, typename weight_t, typename output_t>
|
| 31 |
+
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
|
| 32 |
+
|
| 33 |
+
template <int knrows, typename input_t, typename weight_t, typename output_t>
|
| 34 |
+
void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
|
| 35 |
+
|
| 36 |
+
void set_ssm_params_fwd(SSMParamsBase &params,
|
| 37 |
+
// sizes
|
| 38 |
+
const size_t batch,
|
| 39 |
+
const size_t dim,
|
| 40 |
+
const size_t seqlen,
|
| 41 |
+
const size_t dstate,
|
| 42 |
+
const size_t n_groups,
|
| 43 |
+
const size_t n_chunks,
|
| 44 |
+
// device pointers
|
| 45 |
+
const at::Tensor u,
|
| 46 |
+
const at::Tensor delta,
|
| 47 |
+
const at::Tensor A,
|
| 48 |
+
const at::Tensor B,
|
| 49 |
+
const at::Tensor C,
|
| 50 |
+
const at::Tensor out,
|
| 51 |
+
void* D_ptr,
|
| 52 |
+
void* delta_bias_ptr,
|
| 53 |
+
void* x_ptr,
|
| 54 |
+
bool delta_softplus) {
|
| 55 |
+
|
| 56 |
+
// Reset the parameters
|
| 57 |
+
memset(&params, 0, sizeof(params));
|
| 58 |
+
|
| 59 |
+
params.batch = batch;
|
| 60 |
+
params.dim = dim;
|
| 61 |
+
params.seqlen = seqlen;
|
| 62 |
+
params.dstate = dstate;
|
| 63 |
+
params.n_groups = n_groups;
|
| 64 |
+
params.n_chunks = n_chunks;
|
| 65 |
+
params.dim_ngroups_ratio = dim / n_groups;
|
| 66 |
+
|
| 67 |
+
params.delta_softplus = delta_softplus;
|
| 68 |
+
|
| 69 |
+
// Set the pointers and strides.
|
| 70 |
+
params.u_ptr = u.data_ptr();
|
| 71 |
+
params.delta_ptr = delta.data_ptr();
|
| 72 |
+
params.A_ptr = A.data_ptr();
|
| 73 |
+
params.B_ptr = B.data_ptr();
|
| 74 |
+
params.C_ptr = C.data_ptr();
|
| 75 |
+
params.D_ptr = D_ptr;
|
| 76 |
+
params.delta_bias_ptr = delta_bias_ptr;
|
| 77 |
+
params.out_ptr = out.data_ptr();
|
| 78 |
+
params.x_ptr = x_ptr;
|
| 79 |
+
|
| 80 |
+
// All stride are in elements, not bytes.
|
| 81 |
+
params.A_d_stride = A.stride(0);
|
| 82 |
+
params.A_dstate_stride = A.stride(1);
|
| 83 |
+
params.B_batch_stride = B.stride(0);
|
| 84 |
+
params.B_group_stride = B.stride(1);
|
| 85 |
+
params.B_dstate_stride = B.stride(2);
|
| 86 |
+
params.C_batch_stride = C.stride(0);
|
| 87 |
+
params.C_group_stride = C.stride(1);
|
| 88 |
+
params.C_dstate_stride = C.stride(2);
|
| 89 |
+
params.u_batch_stride = u.stride(0);
|
| 90 |
+
params.u_d_stride = u.stride(1);
|
| 91 |
+
params.delta_batch_stride = delta.stride(0);
|
| 92 |
+
params.delta_d_stride = delta.stride(1);
|
| 93 |
+
|
| 94 |
+
params.out_batch_stride = out.stride(0);
|
| 95 |
+
params.out_d_stride = out.stride(1);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
void set_ssm_params_bwd(SSMParamsBwd &params,
|
| 99 |
+
// sizes
|
| 100 |
+
const size_t batch,
|
| 101 |
+
const size_t dim,
|
| 102 |
+
const size_t seqlen,
|
| 103 |
+
const size_t dstate,
|
| 104 |
+
const size_t n_groups,
|
| 105 |
+
const size_t n_chunks,
|
| 106 |
+
// device pointers
|
| 107 |
+
const at::Tensor u,
|
| 108 |
+
const at::Tensor delta,
|
| 109 |
+
const at::Tensor A,
|
| 110 |
+
const at::Tensor B,
|
| 111 |
+
const at::Tensor C,
|
| 112 |
+
const at::Tensor out,
|
| 113 |
+
void* D_ptr,
|
| 114 |
+
void* delta_bias_ptr,
|
| 115 |
+
void* x_ptr,
|
| 116 |
+
const at::Tensor dout,
|
| 117 |
+
const at::Tensor du,
|
| 118 |
+
const at::Tensor ddelta,
|
| 119 |
+
const at::Tensor dA,
|
| 120 |
+
const at::Tensor dB,
|
| 121 |
+
const at::Tensor dC,
|
| 122 |
+
void* dD_ptr,
|
| 123 |
+
void* ddelta_bias_ptr,
|
| 124 |
+
bool delta_softplus) {
|
| 125 |
+
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
|
| 126 |
+
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks,
|
| 127 |
+
u, delta, A, B, C, dout,
|
| 128 |
+
D_ptr, delta_bias_ptr, x_ptr, delta_softplus);
|
| 129 |
+
|
| 130 |
+
// Set the pointers and strides.
|
| 131 |
+
params.dout_ptr = dout.data_ptr();
|
| 132 |
+
params.du_ptr = du.data_ptr();
|
| 133 |
+
params.dA_ptr = dA.data_ptr();
|
| 134 |
+
params.dB_ptr = dB.data_ptr();
|
| 135 |
+
params.dC_ptr = dC.data_ptr();
|
| 136 |
+
params.dD_ptr = dD_ptr;
|
| 137 |
+
params.ddelta_ptr = ddelta.data_ptr();
|
| 138 |
+
params.ddelta_bias_ptr = ddelta_bias_ptr;
|
| 139 |
+
// All stride are in elements, not bytes.
|
| 140 |
+
params.dout_batch_stride = dout.stride(0);
|
| 141 |
+
params.dout_d_stride = dout.stride(1);
|
| 142 |
+
params.dA_d_stride = dA.stride(0);
|
| 143 |
+
params.dA_dstate_stride = dA.stride(1);
|
| 144 |
+
params.dB_batch_stride = dB.stride(0);
|
| 145 |
+
params.dB_group_stride = dB.stride(1);
|
| 146 |
+
params.dB_dstate_stride = dB.stride(2);
|
| 147 |
+
params.dC_batch_stride = dC.stride(0);
|
| 148 |
+
params.dC_group_stride = dC.stride(1);
|
| 149 |
+
params.dC_dstate_stride = dC.stride(2);
|
| 150 |
+
params.du_batch_stride = du.stride(0);
|
| 151 |
+
params.du_d_stride = du.stride(1);
|
| 152 |
+
params.ddelta_batch_stride = ddelta.stride(0);
|
| 153 |
+
params.ddelta_d_stride = ddelta.stride(1);
|
| 154 |
+
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
std::vector<at::Tensor>
|
| 158 |
+
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
|
| 159 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
| 160 |
+
const c10::optional<at::Tensor> &D_,
|
| 161 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
| 162 |
+
bool delta_softplus,
|
| 163 |
+
int nrows,
|
| 164 |
+
bool out_float
|
| 165 |
+
) {
|
| 166 |
+
auto input_type = u.scalar_type();
|
| 167 |
+
auto weight_type = A.scalar_type();
|
| 168 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 169 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
| 170 |
+
|
| 171 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
| 172 |
+
TORCH_CHECK(B.scalar_type() == input_type);
|
| 173 |
+
TORCH_CHECK(C.scalar_type() == input_type);
|
| 174 |
+
|
| 175 |
+
TORCH_CHECK(u.is_cuda());
|
| 176 |
+
TORCH_CHECK(delta.is_cuda());
|
| 177 |
+
TORCH_CHECK(A.is_cuda());
|
| 178 |
+
TORCH_CHECK(B.is_cuda());
|
| 179 |
+
TORCH_CHECK(C.is_cuda());
|
| 180 |
+
|
| 181 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
| 182 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
| 183 |
+
|
| 184 |
+
const auto sizes = u.sizes();
|
| 185 |
+
const int batch_size = sizes[0];
|
| 186 |
+
const int dim = sizes[1];
|
| 187 |
+
const int seqlen = sizes[2];
|
| 188 |
+
const int dstate = A.size(1);
|
| 189 |
+
const int n_groups = B.size(1);
|
| 190 |
+
|
| 191 |
+
TORCH_CHECK(dim % n_groups == 0, "dims should be divisible by n_groups");
|
| 192 |
+
TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256");
|
| 193 |
+
|
| 194 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
| 195 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
| 196 |
+
CHECK_SHAPE(A, dim, dstate);
|
| 197 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
| 198 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
| 199 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
| 200 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
| 201 |
+
|
| 202 |
+
if (D_.has_value()) {
|
| 203 |
+
auto D = D_.value();
|
| 204 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
| 205 |
+
TORCH_CHECK(D.is_cuda());
|
| 206 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
| 207 |
+
CHECK_SHAPE(D, dim);
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
if (delta_bias_.has_value()) {
|
| 211 |
+
auto delta_bias = delta_bias_.value();
|
| 212 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
| 213 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
| 214 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
| 215 |
+
CHECK_SHAPE(delta_bias, dim);
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel
|
| 219 |
+
at::Tensor out = torch::empty({batch_size, dim, seqlen}, u.options().dtype(out_float? (at::ScalarType::Float): input_type));
|
| 220 |
+
at::Tensor x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
|
| 221 |
+
|
| 222 |
+
SSMParamsBase params;
|
| 223 |
+
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
|
| 224 |
+
u, delta, A, B, C, out,
|
| 225 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
| 226 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
| 227 |
+
x.data_ptr(),
|
| 228 |
+
delta_softplus);
|
| 229 |
+
|
| 230 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 231 |
+
// Cast to char to avoid compiler warning about narrowing
|
| 232 |
+
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
| 233 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 234 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
| 235 |
+
if (!out_float) {
|
| 236 |
+
selective_scan_fwd_cuda<1, input_t, weight_t, input_t>(params, stream);
|
| 237 |
+
} else {
|
| 238 |
+
selective_scan_fwd_cuda<1, input_t, weight_t, float>(params, stream);
|
| 239 |
+
}
|
| 240 |
+
});
|
| 241 |
+
std::vector<at::Tensor> result = {out, x};
|
| 242 |
+
return result;
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
std::vector<at::Tensor>
|
| 246 |
+
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
|
| 247 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
| 248 |
+
const c10::optional<at::Tensor> &D_,
|
| 249 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
| 250 |
+
const at::Tensor &dout,
|
| 251 |
+
const c10::optional<at::Tensor> &x_,
|
| 252 |
+
bool delta_softplus,
|
| 253 |
+
int nrows
|
| 254 |
+
) {
|
| 255 |
+
auto input_type = u.scalar_type();
|
| 256 |
+
auto weight_type = A.scalar_type();
|
| 257 |
+
auto output_type = dout.scalar_type();
|
| 258 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 259 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
| 260 |
+
TORCH_CHECK(output_type == input_type || output_type == at::ScalarType::Float);
|
| 261 |
+
|
| 262 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
| 263 |
+
TORCH_CHECK(B.scalar_type() == input_type);
|
| 264 |
+
TORCH_CHECK(C.scalar_type() == input_type);
|
| 265 |
+
|
| 266 |
+
TORCH_CHECK(u.is_cuda());
|
| 267 |
+
TORCH_CHECK(delta.is_cuda());
|
| 268 |
+
TORCH_CHECK(A.is_cuda());
|
| 269 |
+
TORCH_CHECK(B.is_cuda());
|
| 270 |
+
TORCH_CHECK(C.is_cuda());
|
| 271 |
+
TORCH_CHECK(dout.is_cuda());
|
| 272 |
+
|
| 273 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
| 274 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
| 275 |
+
TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
|
| 276 |
+
|
| 277 |
+
const auto sizes = u.sizes();
|
| 278 |
+
const int batch_size = sizes[0];
|
| 279 |
+
const int dim = sizes[1];
|
| 280 |
+
const int seqlen = sizes[2];
|
| 281 |
+
const int dstate = A.size(1);
|
| 282 |
+
const int n_groups = B.size(1);
|
| 283 |
+
|
| 284 |
+
TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups");
|
| 285 |
+
TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256");
|
| 286 |
+
|
| 287 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
| 288 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
| 289 |
+
CHECK_SHAPE(A, dim, dstate);
|
| 290 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
|
| 291 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
| 292 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
| 293 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
| 294 |
+
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
| 295 |
+
|
| 296 |
+
if (D_.has_value()) {
|
| 297 |
+
auto D = D_.value();
|
| 298 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
| 299 |
+
TORCH_CHECK(D.is_cuda());
|
| 300 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
| 301 |
+
CHECK_SHAPE(D, dim);
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
if (delta_bias_.has_value()) {
|
| 305 |
+
auto delta_bias = delta_bias_.value();
|
| 306 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
| 307 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
| 308 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
| 309 |
+
CHECK_SHAPE(delta_bias, dim);
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
at::Tensor out;
|
| 313 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
| 314 |
+
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
| 315 |
+
if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
|
| 316 |
+
if (x_.has_value()) {
|
| 317 |
+
auto x = x_.value();
|
| 318 |
+
TORCH_CHECK(x.scalar_type() == weight_type);
|
| 319 |
+
TORCH_CHECK(x.is_cuda());
|
| 320 |
+
TORCH_CHECK(x.is_contiguous());
|
| 321 |
+
CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
at::Tensor du = torch::empty_like(u);
|
| 325 |
+
at::Tensor ddelta = torch::empty_like(delta);
|
| 326 |
+
at::Tensor dA = torch::zeros_like(A);
|
| 327 |
+
at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32));
|
| 328 |
+
at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32));
|
| 329 |
+
at::Tensor dD;
|
| 330 |
+
if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
|
| 331 |
+
at::Tensor ddelta_bias;
|
| 332 |
+
if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
|
| 333 |
+
|
| 334 |
+
SSMParamsBwd params;
|
| 335 |
+
set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
|
| 336 |
+
u, delta, A, B, C, out,
|
| 337 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
| 338 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
| 339 |
+
x_.has_value() ? x_.value().data_ptr() : nullptr,
|
| 340 |
+
dout, du, ddelta, dA, dB, dC,
|
| 341 |
+
D_.has_value() ? dD.data_ptr() : nullptr,
|
| 342 |
+
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
|
| 343 |
+
delta_softplus);
|
| 344 |
+
|
| 345 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 346 |
+
// Cast to char to avoid compiler warning about narrowing
|
| 347 |
+
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
| 348 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 349 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
|
| 350 |
+
if (output_type == input_type) {
|
| 351 |
+
selective_scan_bwd_cuda<1, input_t, weight_t, input_t>(params, stream);
|
| 352 |
+
} else {
|
| 353 |
+
selective_scan_bwd_cuda<1, input_t, weight_t, float>(params, stream);
|
| 354 |
+
}
|
| 355 |
+
});
|
| 356 |
+
std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
|
| 357 |
+
return result;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 361 |
+
m.def("fwd", &selective_scan_fwd, "Selective scan forward");
|
| 362 |
+
m.def("bwd", &selective_scan_bwd, "Selective scan backward");
|
| 363 |
+
}
|
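To make the calling convention above concrete, here is a minimal host-side sketch of invoking the forward binding. This is our own illustration, not part of the uploaded files: the tensor sizes, the choice nrows = 1, and the assumption that the declarations from this translation unit plus <torch/torch.h> are in scope are all assumptions.

// Illustrative only: shape/dtype contract enforced by the checks above.
#include <torch/torch.h>

void example_selective_scan_fwd_call() {
    const int64_t batch = 2, dim = 64, seqlen = 512, dstate = 16, n_groups = 1;
    auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);

    at::Tensor u     = torch::randn({batch, dim, seqlen}, opts);            // input sequence
    at::Tensor delta = torch::rand({batch, dim, seqlen}, opts);             // per-step timescale
    at::Tensor A     = -torch::rand({dim, dstate}, opts);                   // weights must be fp32
    at::Tensor B     = torch::randn({batch, n_groups, dstate, seqlen}, opts);
    at::Tensor C     = torch::randn({batch, n_groups, dstate, seqlen}, opts);
    c10::optional<at::Tensor> D          = torch::randn({dim}, opts);       // optional skip term
    c10::optional<at::Tensor> delta_bias = torch::zeros({dim}, opts);       // optional bias on delta

    // delta_softplus = true, nrows = 1, out_float = true
    auto outs = selective_scan_fwd(u, delta, A, B, C, D, delta_bias, true, 1, true);
    at::Tensor out = outs[0];  // (batch, dim, seqlen)
    at::Tensor x   = outs[1];  // per-chunk states, (batch, dim, n_chunks, 2 * dstate)
}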
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/reverse_scan.cuh
ADDED
@@ -0,0 +1,403 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cub/config.cuh>

#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
#include <cub/block/block_raking_layout.cuh>
// #include <cub/detail/uninitialized_copy.cuh>
#include "uninitialized_copy.cuh"
#include "cub_extra.cuh"

/**
 * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
 */
template <
    int LENGTH,
    typename T,
    typename ReductionOp>
__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
    static_assert(LENGTH > 0);
    T retval = input[LENGTH - 1];
    #pragma unroll
    for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
    return retval;
}

/**
 * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
 */
template <
    int LENGTH,
    typename T,
    typename ScanOp>
__device__ __forceinline__ T ThreadReverseScanInclusive(
    const T (&input)[LENGTH],
    T (&output)[LENGTH],
    ScanOp scan_op,
    const T postfix)
{
    T inclusive = postfix;
    #pragma unroll
    for (int i = LENGTH - 1; i >= 0; --i) {
        inclusive = scan_op(inclusive, input[i]);
        output[i] = inclusive;
    }
}

/**
 * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
 */
template <
    int LENGTH,
    typename T,
    typename ScanOp>
__device__ __forceinline__ T ThreadReverseScanExclusive(
    const T (&input)[LENGTH],
    T (&output)[LENGTH],
    ScanOp scan_op,
    const T postfix)
{
    // Careful, output maybe be aliased to input
    T exclusive = postfix;
    T inclusive;
    #pragma unroll
    for (int i = LENGTH - 1; i >= 0; --i) {
        inclusive = scan_op(exclusive, input[i]);
        output[i] = exclusive;
        exclusive = inclusive;
    }
    return inclusive;
}


/**
 * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
 *
 * LOGICAL_WARP_THREADS must be a power-of-two
 */
template <
    typename T,                 ///< Data type being scanned
    int LOGICAL_WARP_THREADS    ///< Number of threads per logical warp
    >
struct WarpReverseScan {
    //---------------------------------------------------------------------
    // Constants and type definitions
    //---------------------------------------------------------------------

    /// Whether the logical warp size and the PTX warp size coincide
    static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
    /// The number of warp scan steps
    static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
    static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);


    //---------------------------------------------------------------------
    // Thread fields
    //---------------------------------------------------------------------

    /// Lane index in logical warp
    unsigned int lane_id;

    /// Logical warp index in 32-thread physical warp
    unsigned int warp_id;

    /// 32-thread physical warp member mask of logical warp
    unsigned int member_mask;

    //---------------------------------------------------------------------
    // Construction
    //---------------------------------------------------------------------

    /// Constructor
    explicit __device__ __forceinline__
    WarpReverseScan()
        : lane_id(cub::LaneId())
        , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
        // , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
        , member_mask(WarpMask<LOGICAL_WARP_THREADS>(warp_id))
    {
        if (!IS_ARCH_WARP) {
            lane_id = lane_id % LOGICAL_WARP_THREADS;
        }
    }


    /// Broadcast
    __device__ __forceinline__ T Broadcast(
        T input,        ///< [in] The value to broadcast
        int src_lane)   ///< [in] Which warp lane is to do the broadcasting
    {
        return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
    }


    /// Inclusive scan
    template <typename ScanOpT>
    __device__ __forceinline__ void InclusiveReverseScan(
        T input,                ///< [in] Calling thread's input item.
        T &inclusive_output,    ///< [out] Calling thread's output item. May be aliased with \p input.
        ScanOpT scan_op)        ///< [in] Binary scan operator
    {
        inclusive_output = input;
        #pragma unroll
        for (int STEP = 0; STEP < STEPS; STEP++) {
            int offset = 1 << STEP;
            T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
                inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
            );
            // Perform scan op if from a valid peer
            inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
                ? inclusive_output : scan_op(temp, inclusive_output);
        }
    }

    /// Exclusive scan
    // Get exclusive from inclusive
    template <typename ScanOpT>
    __device__ __forceinline__ void ExclusiveReverseScan(
        T input,                ///< [in] Calling thread's input item.
        T &exclusive_output,    ///< [out] Calling thread's output item. May be aliased with \p input.
        ScanOpT scan_op,        ///< [in] Binary scan operator
        T &warp_aggregate)      ///< [out] Warp-wide aggregate reduction of input items.
    {
        T inclusive_output;
        InclusiveReverseScan(input, inclusive_output, scan_op);
        warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
        // initial value unknown
        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
        );
    }

    /**
     * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
     */
    template <typename ScanOpT>
    __device__ __forceinline__ void ReverseScan(
        T input,                ///< [in] Calling thread's input item.
        T &inclusive_output,    ///< [out] Calling thread's inclusive-scan output item.
        T &exclusive_output,    ///< [out] Calling thread's exclusive-scan output item.
        ScanOpT scan_op)        ///< [in] Binary scan operator
    {
        InclusiveReverseScan(input, inclusive_output, scan_op);
        // initial value unknown
        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
        );
    }

};

/**
 * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
 */
template <
    typename T,         ///< Data type being scanned
    int BLOCK_DIM_X,    ///< The thread block length in threads along the X dimension
    bool MEMOIZE=false  ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
    >
struct BlockReverseScan {
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    /// Constants
    /// The thread block size in threads
    static constexpr int BLOCK_THREADS = BLOCK_DIM_X;

    /// Layout type for padded thread block raking grid
    using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
    // The number of reduction elements is not a multiple of the number of raking threads for now
    static_assert(BlockRakingLayout::UNGUARDED);

    /// Number of raking threads
    static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
    /// Number of raking elements per warp synchronous raking thread
    static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
    /// Cooperative work can be entirely warp synchronous
    static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));

    /// WarpReverseScan utility type
    using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;

    /// Shared memory storage layout type
    struct _TempStorage {
        typename BlockRakingLayout::TempStorage raking_grid;    ///< Padded thread block raking grid
    };


    /// Alias wrapper allowing storage to be unioned
    struct TempStorage : cub::Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    // Thread fields
    _TempStorage &temp_storage;
    unsigned int linear_tid;
    T cached_segment[SEGMENT_LENGTH];


    //---------------------------------------------------------------------
    // Utility methods
    //---------------------------------------------------------------------

    /// Performs upsweep raking reduction, returning the aggregate
    template <typename ScanOp>
    __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
        // Read data into registers
        #pragma unroll
        for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
        T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
        #pragma unroll
        for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
            raking_partial = scan_op(raking_partial, cached_segment[i]);
        }
        return raking_partial;
    }


    /// Performs exclusive downsweep raking scan
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveDownsweep(
        ScanOp scan_op,
        T raking_partial)
    {
        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
        // Read data back into registers
        if (!MEMOIZE) {
            #pragma unroll
            for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
        }
        ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
        // Write data back to smem
        #pragma unroll
        for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
    }


    //---------------------------------------------------------------------
    // Constructors
    //---------------------------------------------------------------------

    /// Constructor
    __device__ __forceinline__ BlockReverseScan(
        TempStorage &temp_storage)
        :
        temp_storage(temp_storage.Alias()),
        linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
    {}


    /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
    template <
        typename ScanOp,
        typename BlockPostfixCallbackOp>
    __device__ __forceinline__ void ExclusiveReverseScan(
        T input,                ///< [in] Calling thread's input item
        T &exclusive_output,    ///< [out] Calling thread's output item (may be aliased to \p input)
        ScanOp scan_op,         ///< [in] Binary scan operator
        BlockPostfixCallbackOp &block_postfix_callback_op)  ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
    {
        if (WARP_SYNCHRONOUS) {
            // Short-circuit directly to warp-synchronous scan
            T block_aggregate;
            WarpReverseScan warp_scan;
            warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
            // Obtain warp-wide postfix in lane0, then broadcast to other lanes
            T block_postfix = block_postfix_callback_op(block_aggregate);
            block_postfix = warp_scan.Broadcast(block_postfix, 0);
            exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
        } else {
            // Place thread partial into shared memory raking grid
            T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
            detail::uninitialized_copy(placement_ptr, input);
            cub::CTA_SYNC();
            // Reduce parallelism down to just raking threads
            if (linear_tid < RAKING_THREADS) {
                WarpReverseScan warp_scan;
                // Raking upsweep reduction across shared partials
                T upsweep_partial = Upsweep(scan_op);
                // Warp-synchronous scan
                T exclusive_partial, block_aggregate;
                warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
                // Obtain block-wide postfix in lane0, then broadcast to other lanes
                T block_postfix = block_postfix_callback_op(block_aggregate);
                block_postfix = warp_scan.Broadcast(block_postfix, 0);
                // Update postfix with warpscan exclusive partial
                T downsweep_postfix = linear_tid == RAKING_THREADS - 1
                    ? block_postfix : scan_op(block_postfix, exclusive_partial);
                // Exclusive raking downsweep scan
                ExclusiveDownsweep(scan_op, downsweep_postfix);
            }
            cub::CTA_SYNC();
            // Grab thread postfix from shared memory
            exclusive_output = *placement_ptr;

            // // Compute warp scan in each warp.
            // // The exclusive output from the last lane in each warp is invalid.
            // T inclusive_output;
            // WarpReverseScan warp_scan;
            // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);

            // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
            // T block_aggregate;
            // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);

            // // Apply warp postfix to our lane's partial
            // if (warp_id != 0) {
            //     exclusive_output = scan_op(warp_postfix, exclusive_output);
            //     if (lane_id == 0) { exclusive_output = warp_postfix; }
            // }

            // // Use the first warp to determine the thread block postfix, returning the result in lane0
            // if (warp_id == 0) {
            //     T block_postfix = block_postfix_callback_op(block_aggregate);
            //     if (lane_id == 0) {
            //         // Share the postfix with all threads
            //         detail::uninitialized_copy(&temp_storage.block_postfix,
            //                                    block_postfix);

            //         exclusive_output = block_postfix;  // The block postfix is the exclusive output for tid0
            //     }
            // }

            // cub::CTA_SYNC();

            // // Incorporate thread block postfix into outputs
            // T block_postfix = temp_storage.block_postfix;
            // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
        }
    }


    /**
     * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
     */
    template <
        int ITEMS_PER_THREAD,
        typename ScanOp,
        typename BlockPostfixCallbackOp>
    __device__ __forceinline__ void InclusiveReverseScan(
        T (&input)[ITEMS_PER_THREAD],   ///< [in] Calling thread's input items
        T (&output)[ITEMS_PER_THREAD],  ///< [out] Calling thread's output items (may be aliased to \p input)
        ScanOp scan_op,                 ///< [in] Binary scan functor
        BlockPostfixCallbackOp &block_postfix_callback_op)  ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
    {
        // Reduce consecutive thread items in registers
        T thread_postfix = ThreadReverseReduce(input, scan_op);
        // Exclusive thread block-scan
        ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
        // Inclusive scan in registers with postfix as seed
        ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
    }

};
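As a plain-CPU illustration of what the thread-level postfix scans above compute: a reverse (postfix) scan walks the array right to left, so element i receives the combination of everything strictly to its right plus the seed. The sketch below is our own, uses integer addition as a stand-in scan_op, and is not part of the uploaded kernels.

// Host-side sketch mirroring ThreadReverseScanExclusive (illustrative only).
#include <cstdio>

int main() {
    const int LENGTH = 4;
    int input[LENGTH] = {1, 2, 3, 4};
    int output[LENGTH];
    int postfix = 0;                           // seed applied "after" the last element

    int exclusive = postfix;
    for (int i = LENGTH - 1; i >= 0; --i) {
        int inclusive = exclusive + input[i];  // scan_op = addition here
        output[i] = exclusive;                 // exclusive: excludes input[i] itself
        exclusive = inclusive;
    }
    // output is {9, 7, 4, 0}: the sum of the elements strictly to the right of each index.
    for (int i = 0; i < LENGTH; ++i) { printf("%d ", output[i]); }
    printf("\n");
    return 0;
}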
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan.h
ADDED
@@ -0,0 +1,90 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SSMScanParamsBase {
    using index_t = uint32_t;

    int batch, seqlen, n_chunks;
    index_t a_batch_stride;
    index_t b_batch_stride;
    index_t out_batch_stride;

    // Common data pointers.
    void *__restrict__ a_ptr;
    void *__restrict__ b_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SSMParamsBase {
    using index_t = uint32_t;

    int batch, dim, seqlen, dstate, n_groups, n_chunks;
    int dim_ngroups_ratio;

    bool delta_softplus;

    index_t A_d_stride;
    index_t A_dstate_stride;
    index_t B_batch_stride;
    index_t B_d_stride;
    index_t B_dstate_stride;
    index_t B_group_stride;
    index_t C_batch_stride;
    index_t C_d_stride;
    index_t C_dstate_stride;
    index_t C_group_stride;
    index_t u_batch_stride;
    index_t u_d_stride;
    index_t delta_batch_stride;
    index_t delta_d_stride;
    index_t out_batch_stride;
    index_t out_d_stride;

    // Common data pointers.
    void *__restrict__ A_ptr;
    void *__restrict__ B_ptr;
    void *__restrict__ C_ptr;
    void *__restrict__ D_ptr;
    void *__restrict__ u_ptr;
    void *__restrict__ delta_ptr;
    void *__restrict__ delta_bias_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
};

struct SSMParamsBwd: public SSMParamsBase {
    index_t dout_batch_stride;
    index_t dout_d_stride;
    index_t dA_d_stride;
    index_t dA_dstate_stride;
    index_t dB_batch_stride;
    index_t dB_group_stride;
    index_t dB_d_stride;
    index_t dB_dstate_stride;
    index_t dC_batch_stride;
    index_t dC_group_stride;
    index_t dC_d_stride;
    index_t dC_dstate_stride;
    index_t du_batch_stride;
    index_t du_d_stride;
    index_t ddelta_batch_stride;
    index_t ddelta_d_stride;

    // Common data pointers.
    void *__restrict__ dout_ptr;
    void *__restrict__ dA_ptr;
    void *__restrict__ dB_ptr;
    void *__restrict__ dC_ptr;
    void *__restrict__ dD_ptr;
    void *__restrict__ du_ptr;
    void *__restrict__ ddelta_ptr;
    void *__restrict__ ddelta_bias_ptr;
};
rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan_common.h
ADDED
@@ -0,0 +1,210 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <c10/util/complex.h>  // For scalar_value_type

#define MAX_DSTATE 256

inline __device__ float2 operator+(const float2 & a, const float2 & b){
    return {a.x + b.x, a.y + b.y};
}

inline __device__ float3 operator+(const float3 &a, const float3 &b) {
    return {a.x + b.x, a.y + b.y, a.z + b.z};
}

inline __device__ float4 operator+(const float4 & a, const float4 & b){
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename scalar_t, int N>
struct Converter{
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
        #pragma unroll
        for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
    }
};

template<int N>
struct Converter<at::Half, N>{
    static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
    }
};

#if __CUDA_ARCH__ >= 800
template<int N>
struct Converter<at::BFloat16, N>{
    static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
    }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename scalar_t> struct SSMScanOp;

template<>
struct SSMScanOp<float> {
    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
    }
};

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
    scan_t running_prefix;
    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ scan_t operator()(scan_t block_aggregate) {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
        return old_prefix;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t *u,
                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadVecT(smem_load_vec).Load(
            reinterpret_cast<vec_t*>(u),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
        );
    } else {
        Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}

template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
                                   int seqlen) {
    constexpr int kNItems = Ktraits::kNItems;
    typename Ktraits::input_t B_vals_load[kNItems];
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
            reinterpret_cast<vec_t*>(Bvar),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
        );
    } else {
        Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
    }
    // #pragma unroll
    // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
    Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
}

template<typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t *out,
                                    const float (&out_vals)[Ktraits::kNItems],
                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,
                                    int seqlen) {
    typename Ktraits::input_t write_vals[Ktraits::kNItems];
    #pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockStoreVecT(smem_store_vec).Store(
            reinterpret_cast<vec_t*>(out),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
        );
    } else {
        Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
    }
}

template<typename Ktraits>
inline __device__ void store_output1(typename Ktraits::output_t *out,
                                     const float (&out_vals)[Ktraits::kNItems],
                                     typename Ktraits::BlockStoreOutputT::TempStorage &smem_store,
                                     int seqlen) {
    typename Ktraits::output_t write_vals[Ktraits::kNItems];
    #pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreOutputVecT::TempStorage&>(smem_store);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockStoreOutputVecT(smem_store_vec).Store(
            reinterpret_cast<vec_t*>(out),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoadsOutput]>(write_vals)
        );
    } else {
        Ktraits::BlockStoreOutputT(smem_store).Store(out, write_vals, seqlen);
    }
}

template<typename Ktraits>
inline __device__ void load_output(typename Ktraits::output_t *u,
                                   typename Ktraits::output_t (&u_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadOutputT::TempStorage &smem_load,
                                   int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadOutputVecT::TempStorage&>(smem_load);
        using vec_t = typename Ktraits::vec_t;
        Ktraits::BlockLoadOutputVecT(smem_load_vec).Load(
            reinterpret_cast<vec_t*>(u),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoadsOutput]>(u_vals)
        );
    } else {
        Ktraits::BlockLoadOutputT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}
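For reference, the combine rule in SSMScanOp<float> above composes two affine steps h -> a*h + b into one, which is why a parallel scan over (a_t, b_t) pairs (seeded across chunks by SSMScanPrefixCallbackOp) reproduces the sequential SSM recurrence. The following is our own host-side sketch checking that equivalence with standard C++ only; it is not part of the uploaded files.

// Illustrative only: each scan element (a_t, b_t) encodes one step h = a_t * h + b_t.
#include <cassert>
#include <cmath>
#include <utility>

using Elem = std::pair<float, float>;  // (a, b) acting as h -> a * h + b

// Same combine rule as SSMScanOp<float>: apply e0 first, then e1.
Elem combine(const Elem &e0, const Elem &e1) {
    return {e1.first * e0.first, e1.first * e0.second + e1.second};
}

int main() {
    const Elem steps[4] = {{0.9f, 1.0f}, {0.8f, -0.5f}, {1.1f, 0.25f}, {0.7f, 2.0f}};

    // Sequential recurrence starting from h0 = 0.
    float h = 0.f;
    for (const Elem &e : steps) { h = e.first * h + e.second; }

    // Fold the same steps with the scan combine, then apply the composed map once.
    Elem total = steps[0];
    for (int i = 1; i < 4; ++i) { total = combine(total, steps[i]); }
    float h_scan = total.first * 0.f + total.second;

    assert(std::fabs(h - h_scan) < 1e-5f);  // both orders give the same final state
    return 0;
}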