diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..20014bf069b5e32e2daa23df48ef3ae46fe4ab68 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps filter=lfs diff=lfs merge=lfs -text +rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o filter=lfs diff=lfs merge=lfs -text +rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o filter=lfs diff=lfs merge=lfs -text +rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o filter=lfs diff=lfs merge=lfs -text diff --git a/rscd/models/backbones/lib_mamba/__init__.py b/rscd/models/backbones/lib_mamba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03b12948957f4217f432ea5aa0861a7c8c091f13 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/__init__.py @@ -0,0 +1,58 @@ +import os +from functools import partial +import torch + +from .vmamba import VSSM +from .csms6s import flops_selective_scan_fn,flops_selective_scan_ref + + +def build_vssm_model(config, **kwargs): + model_type = config.MODEL.TYPE + if model_type in ["vssm"]: + model = VSSM( + patch_size=config.MODEL.VSSM.PATCH_SIZE, + in_chans=config.MODEL.VSSM.IN_CHANS, + num_classes=config.MODEL.NUM_CLASSES, + depths=config.MODEL.VSSM.DEPTHS, + dims=config.MODEL.VSSM.EMBED_DIM, + # =================== + ssm_d_state=config.MODEL.VSSM.SSM_D_STATE, + ssm_ratio=config.MODEL.VSSM.SSM_RATIO, + ssm_rank_ratio=config.MODEL.VSSM.SSM_RANK_RATIO, + ssm_dt_rank=("auto" if config.MODEL.VSSM.SSM_DT_RANK == "auto" else int(config.MODEL.VSSM.SSM_DT_RANK)), + ssm_act_layer=config.MODEL.VSSM.SSM_ACT_LAYER, + ssm_conv=config.MODEL.VSSM.SSM_CONV, + ssm_conv_bias=config.MODEL.VSSM.SSM_CONV_BIAS, + ssm_drop_rate=config.MODEL.VSSM.SSM_DROP_RATE, + ssm_init=config.MODEL.VSSM.SSM_INIT, + forward_type=config.MODEL.VSSM.SSM_FORWARDTYPE, + # =================== + mlp_ratio=config.MODEL.VSSM.MLP_RATIO, + mlp_act_layer=config.MODEL.VSSM.MLP_ACT_LAYER, + mlp_drop_rate=config.MODEL.VSSM.MLP_DROP_RATE, + # =================== + drop_path_rate=config.MODEL.DROP_PATH_RATE, + patch_norm=config.MODEL.VSSM.PATCH_NORM, + norm_layer=config.MODEL.VSSM.NORM_LAYER, + downsample_version=config.MODEL.VSSM.DOWNSAMPLE, + patchembed_version=config.MODEL.VSSM.PATCHEMBED, + gmlp=config.MODEL.VSSM.GMLP, + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + # =================== + posembed=config.MODEL.VSSM.POSEMBED, + imgsize=config.DATA.IMG_SIZE, + ) + return model + + return None + + +def build_model(config, is_pretrain=False): + model = None + if model is None: + model = build_vssm_model(config, is_pretrain) + return model + + + + diff --git a/rscd/models/backbones/lib_mamba/__pycache__/__init__.cpython-38.pyc b/rscd/models/backbones/lib_mamba/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..86d217990cee78b2d55db843c884adb3785149e0 Binary files /dev/null and b/rscd/models/backbones/lib_mamba/__pycache__/__init__.cpython-38.pyc differ diff --git a/rscd/models/backbones/lib_mamba/__pycache__/csm_triton.cpython-38.pyc b/rscd/models/backbones/lib_mamba/__pycache__/csm_triton.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f07a544b32b9b28f5c320d2e4ee5f00b1444acd2 Binary files /dev/null and b/rscd/models/backbones/lib_mamba/__pycache__/csm_triton.cpython-38.pyc differ diff --git a/rscd/models/backbones/lib_mamba/__pycache__/csm_tritonk2.cpython-38.pyc b/rscd/models/backbones/lib_mamba/__pycache__/csm_tritonk2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e15a5ae74a8fc4b6bbac8d7022a013b9cdbb8be Binary files /dev/null and b/rscd/models/backbones/lib_mamba/__pycache__/csm_tritonk2.cpython-38.pyc differ diff --git a/rscd/models/backbones/lib_mamba/__pycache__/csms6s.cpython-38.pyc b/rscd/models/backbones/lib_mamba/__pycache__/csms6s.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..278936041af50a0624f96a95a45be51bdf5318dd Binary files /dev/null and b/rscd/models/backbones/lib_mamba/__pycache__/csms6s.cpython-38.pyc differ diff --git a/rscd/models/backbones/lib_mamba/__pycache__/vmamba.cpython-38.pyc b/rscd/models/backbones/lib_mamba/__pycache__/vmamba.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8247e80f7722aa91e2ee75feaf03946665fc19fc Binary files /dev/null and b/rscd/models/backbones/lib_mamba/__pycache__/vmamba.cpython-38.pyc differ diff --git a/rscd/models/backbones/lib_mamba/__pycache__/vmambanew.cpython-38.pyc b/rscd/models/backbones/lib_mamba/__pycache__/vmambanew.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73602e415c543bf26ef2e7ee72ad7f1ab55b107e Binary files /dev/null and b/rscd/models/backbones/lib_mamba/__pycache__/vmambanew.cpython-38.pyc differ diff --git a/rscd/models/backbones/lib_mamba/csm_triton.py b/rscd/models/backbones/lib_mamba/csm_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..b91eac68d1a7837f58f91698046af733a2a2fcb8 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/csm_triton.py @@ -0,0 +1,644 @@ +import torch +import warnings + +WITH_TRITON = True +# WITH_TRITON = False +try: + import triton + import triton.language as tl +except: + WITH_TRITON = False + warnings.warn("Triton not installed, fall back to pytorch implements.") + +# to make sure cached_property can be loaded for triton +if WITH_TRITON: + try: + from functools import cached_property + except: + warnings.warn("if you are using py37, add this line to functools.py: " + "cached_property = lambda func: property(lru_cache()(func))") + +# torch implementation ======================================== +def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): + if in_channel_first: + B, C, H, W = x.shape + if scans == 0: + y = x.new_empty((B, 4, C, H * W)) + y[:, 0, :, :] = x.flatten(2, 3) + y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3) + y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1]) + elif scans == 1: + y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1) + elif scans == 2: + y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1) + y = torch.cat([y, y.flip(dims=[-1])], dim=1) + else: + B, H, W, C = x.shape + if scans == 0: + y = x.new_empty((B, H * W, 4, C)) + y[:, :, 0, :] = x.flatten(1, 2) + y[:, :, 1, :] 
= x.transpose(dim0=1, dim1=2).flatten(1, 2) + y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1]) + elif scans == 1: + y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1) + elif scans == 2: + y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1) + y = torch.cat([y, y.flip(dims=[1])], dim=2) + + if in_channel_first and (not out_channel_first): + y = y.permute(0, 3, 1, 2).contiguous() + elif (not in_channel_first) and out_channel_first: + y = y.permute(0, 2, 3, 1).contiguous() + + return y + + +def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): + if out_channel_first: + B, K, D, H, W = y.shape + y = y.view(B, K, D, -1) + if scans == 0: + y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) + y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1) + elif scans == 1: + y = y.sum(1) + elif scans == 2: + y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) + y = y.sum(1) + else: + B, H, W, K, D = y.shape + y = y.view(B, -1, K, D) + if scans == 0: + y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D) + y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D) + elif scans == 1: + y = y.sum(2) + elif scans == 2: + y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D) + y = y.sum(2) + + if in_channel_first and (not out_channel_first): + y = y.permute(0, 2, 1).contiguous() + elif (not in_channel_first) and out_channel_first: + y = y.permute(0, 2, 1).contiguous() + + return y + + +def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): + if in_channel_first: + B, _, C, H, W = x.shape + if scans == 0: + y = torch.stack([ + x[:, 0].flatten(2, 3), + x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3), + torch.flip(x[:, 2].flatten(2, 3), dims=[-1]), + torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]), + ], dim=1) + elif scans == 1: + y = x.flatten(2, 3) + elif scans == 2: + y = torch.stack([ + x[:, 0].flatten(2, 3), + x[:, 1].flatten(2, 3), + torch.flip(x[:, 2].flatten(2, 3), dims=[-1]), + torch.flip(x[:, 3].flatten(2, 3), dims=[-1]), + ], dim=1) + else: + B, H, W, _, C = x.shape + if scans == 0: + y = torch.stack([ + x[:, :, :, 0].flatten(1, 2), + x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2), + torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]), + torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]), + ], dim=2) + elif scans == 1: + y = x.flatten(1, 2) + elif scans == 2: + y = torch.stack([ + x[:, 0].flatten(1, 2), + x[:, 1].flatten(1, 2), + torch.flip(x[:, 2].flatten(1, 2), dims=[-1]), + torch.flip(x[:, 3].flatten(1, 2), dims=[-1]), + ], dim=2) + + if in_channel_first and (not out_channel_first): + y = y.permute(0, 3, 1, 2).contiguous() + elif (not in_channel_first) and out_channel_first: + y = y.permute(0, 2, 3, 1).contiguous() + + return y + + +def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): + if out_channel_first: + B, K, D, H, W = y.shape + y = y.view(B, K, D, -1) + if scans == 0: + y = torch.stack([ + y[:, 0], + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), + torch.flip(y[:, 2], dims=[-1]), + torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]), + ], dim=1) + elif scans == 1: + y = y + elif scans == 2: + y = torch.stack([ + y[:, 0], + y[:, 1], + torch.flip(y[:, 2], dims=[-1]), + torch.flip(y[:, 3], dims=[-1]), + ], dim=1) + else: + B, H, W, _, D = y.shape + y 
= y.view(B, -1, K, D) + if scans == 0: + y = torch.stack([ + y[:, :, 0], + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), + torch.flip(y[:, :, 2], dims=[1]), + torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]), + ], dim=2) + elif scans == 1: + y = y + elif scans == 2: + y = torch.stack([ + y[:, :, 0], + y[:, :, 1], + torch.flip(y[:, :, 2], dims=[1]), + torch.flip(y[:, :, 3], dims=[1]), + ], dim=2) + + if out_channel_first and (not in_channel_first): + y = y.permute(0, 3, 1, 2).contiguous() + elif (not out_channel_first) and in_channel_first: + y = y.permute(0, 2, 3, 1).contiguous() + + return y + + +class CrossScanF(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): + # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + # y: (B, 4, C, H * W) | (B, H * W, 4, C) + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + + if one_by_one: + B, K, C, H, W = x.shape + if not in_channel_first: + B, H, W, K, C = x.shape + else: + B, C, H, W = x.shape + if not in_channel_first: + B, H, W, C = x.shape + ctx.shape = (B, C, H, W) + + _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd + y = _fn(x, in_channel_first, out_channel_first, scans) + + return y + + @staticmethod + def backward(ctx, ys: torch.Tensor): + # out: (b, k, d, l) + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + + ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C) + _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd + y = _fn(ys, in_channel_first, out_channel_first, scans) + + if one_by_one: + y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1) + else: + y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1) + + return y, None, None, None, None + + +class CrossMergeF(torch.autograd.Function): + @staticmethod + def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): + # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + # y: (B, 4, C, H * W) | (B, H * W, 4, C) + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + + B, K, C, H, W = ys.shape + if not out_channel_first: + B, H, W, K, C = ys.shape + ctx.shape = (B, C, H, W) + + _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd + y = _fn(ys, in_channel_first, out_channel_first, scans) + + return y + + @staticmethod + def backward(ctx, x: torch.Tensor): + # B, D, L = x.shape + # out: (b, k, d, h, w) + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + + if not one_by_one: + if in_channel_first: + x = x.view(B, C, H, W) + else: + x = x.view(B, H, W, C) + else: + if in_channel_first: + x = x.view(B, 4, C, H, W) + else: + x = x.view(B, H, W, 4, C) + + _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd + x = _fn(x, in_channel_first, out_channel_first, scans) + x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C) + + return x, None, None, None, None + + +# triton implements ======================================== + +@triton.jit +def 
triton_cross_scan_flex( + x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + y, # (B, 4, C, H, W) | (B, H, W, 4, C) + x_layout: tl.constexpr, + y_layout: tl.constexpr, + operation: tl.constexpr, + onebyone: tl.constexpr, + scans: tl.constexpr, + BC: tl.constexpr, + BH: tl.constexpr, + BW: tl.constexpr, + DC: tl.constexpr, + DH: tl.constexpr, + DW: tl.constexpr, + NH: tl.constexpr, + NW: tl.constexpr, +): + # x_layout = 0 + # y_layout = 1 # 0 BCHW, 1 BHWC + # operation = 0 # 0 scan, 1 merge + # onebyone = 0 # 0 false, 1 true + # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional + + i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_w = (i_hw // NW), (i_hw % NW) + _mask_h = (i_h * BH + tl.arange(0, BH)) < DH + _mask_w = (i_w * BW + tl.arange(0, BW)) < DW + _mask_hw = _mask_h[:, None] & _mask_w[None, :] + _for_C = min(DC - i_c * BC, BC) + + HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] + HWRoute1 = i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None] # trans + HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip + HWRoute3 = (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip + + if scans == 1: + HWRoute1 = HWRoute0 + HWRoute2 = HWRoute0 + HWRoute3 = HWRoute0 + elif scans == 2: + HWRoute1 = HWRoute0 + HWRoute3 = HWRoute2 + + _tmp1 = DC * DH * DW + + y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) + if y_layout == 0: + p_y1 = y_ptr_base + HWRoute0 + p_y2 = y_ptr_base + _tmp1 + HWRoute1 + p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2 + p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3 + else: + p_y1 = y_ptr_base + HWRoute0 * 4 * DC + p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC + p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC + p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC + + if onebyone == 0: + x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x = x_ptr_base + HWRoute0 + else: + p_x = x_ptr_base + HWRoute0 * DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _x = tl.load(p_x + _idx_x, mask=_mask_hw) + tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y3 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y4 + _idx_y, _x, mask=_mask_hw) + elif operation == 1: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) + _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) + _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw) + _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw) + tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw) + + else: + x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x1 = x_ptr_base + HWRoute0 + p_x2 = p_x1 + _tmp1 + p_x3 = p_x2 + _tmp1 + p_x4 = p_x3 + _tmp1 + else: + p_x1 = x_ptr_base + HWRoute0 * 4 * DC + p_x2 = p_x1 + DC + p_x3 = p_x2 + DC + p_x4 = p_x3 + DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = 
idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw) + else: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw) + tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw) + tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw) + tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw) + + +class CrossScanTritonF(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): + if one_by_one: + if in_channel_first: + B, _, C, H, W = x.shape + else: + B, H, W, _, C = x.shape + else: + if in_channel_first: + B, C, H, W = x.shape + else: + B, H, W, C = x.shape + B, C, H, W = int(B), int(C), int(H), int(W) + BC, BH, BW = 1, 32, 32 + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + ctx.shape = (B, C, H, W) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + + y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C)) + triton_cross_scan_flex[(NH * NW, NC, B)]( + x.contiguous(), y, + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return y + + @staticmethod + def backward(ctx, y: torch.Tensor): + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + if one_by_one: + x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C)) + else: + x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C)) + + triton_cross_scan_flex[(NH * NW, NC, B)]( + x, y.contiguous(), + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return x, None, None, None, None + + +class CrossMergeTritonF(torch.autograd.Function): + @staticmethod + def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): + if out_channel_first: + B, _, C, H, W = y.shape + else: + B, H, W, _, C = y.shape + B, C, H, W = int(B), int(C), int(H), int(W) + BC, BH, BW = 1, 32, 32 + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + ctx.shape = (B, C, H, W) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + if one_by_one: + x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C)) + else: + x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C)) + triton_cross_scan_flex[(NH * NW, NC, B)]( + x, y.contiguous(), + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return x + + @staticmethod 
+ def backward(ctx, x: torch.Tensor): + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C)) + triton_cross_scan_flex[(NH * NW, NC, B)]( + x.contiguous(), y, + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return y, None, None, None, None, None + + +# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) +def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): + # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + # y: (B, 4, C, L) | (B, L, 4, C) + # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; + CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF + return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) + + +# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) +def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): + # y: (B, 4, C, L) | (B, L, 4, C) + # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C) + # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; + CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF + return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) + + +# checks ================================================================= + +class CHECK: + def check_csm_triton(): + B, C, H, W = 2, 192, 56, 57 + dtype=torch.float16 + dtype=torch.float32 + x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True) + y = torch.randn((B, 4, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True) + x1 = x.clone().detach().requires_grad_(True) + y1 = y.clone().detach().requires_grad_(True) + + def cross_scan(x: torch.Tensor): + B, C, H, W = x.shape + L = H * W + xs = torch.stack([ + x.view(B, C, L), + torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), + torch.flip(x.contiguous().view(B, C, L), dims=[-1]), + torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]), + ], dim=1).view(B, 4, C, L) + return xs + + def cross_merge(out_y: torch.Tensor): + B, K, D, H, W = out_y.shape + L = H * W + out_y = out_y.view(B, K, D, L) + inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) + wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y + return y + + def cross_scan_1b1(x: torch.Tensor): + B, K, C, H, W = x.shape + L = H * W + xs = torch.stack([ + x[:, 0].view(B, C, L), + torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L), + torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]), + torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]), + ], dim=1).view(B, 4, C, L) + return xs + + def unidi_scan(x): + B, C, H, W = x.shape + x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1) + return x + + def unidi_merge(ys): + B, K, C, H, W = ys.shape + return ys.view(B, 4, -1, H * W).sum(1) + 
+ def bidi_scan(x): + B, C, H, W = x.shape + x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1) + x = torch.cat([x, x.flip(dims=[-1])], dim=1) + return x + + def bidi_merge(ys): + B, K, D, H, W = ys.shape + ys = ys.view(B, K, D, -1) + ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) + return ys.contiguous().sum(1) + + if True: + res0 = triton.testing.do_bench(lambda :cross_scan(x)) + res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False)) + # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x)) + res3 = triton.testing.do_bench(lambda :cross_merge(y)) + res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False)) + # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y)) + # print(res0, res1, res2, res3, res4, res5) + print(res0, res1, res3, res4) + res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward()) + res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False).sum().backward()) + # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x).sum().backward()) + res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward()) + res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False).sum().backward()) + # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y).sum().backward()) + # print(res0, res1, res2, res3, res4, res5) + print(res0, res1, res3, res4) + + print("test cross scan") + for (cs0, cm0, cs1, cm1) in [ + # channel_first -> channel_first + (cross_scan, cross_merge, cross_scan_fn, cross_merge_fn), + (unidi_scan, unidi_merge, lambda x: cross_scan_fn(x, scans=1), lambda x: cross_merge_fn(x, scans=1)), + (bidi_scan, bidi_merge, lambda x: cross_scan_fn(x, scans=2), lambda x: cross_merge_fn(x, scans=2)), + + # flex: BLC->BCL; BCL->BLC; BLC->BLC; + (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn(x, in_channel_first=False).permute(0, 2, 1)), + (cross_scan, cross_merge, lambda x: cross_scan_fn(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), out_channel_first=False)), + (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)), + + # previous + # (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)), + # (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)), + # (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)), + ]: + x.grad, x1.grad, y.grad, y1.grad = None, None, None, None + o0 = cs0(x) + o1 = cs1(x1) + o0.backward(y.view(B, 4, C, H * W)) + o1.backward(y.view(B, 4, C, H * W)) + print((o0 - o1).abs().max()) + print((x.grad - x1.grad).abs().max()) + o0 = cm0(y) + o1 = cm1(y1) + o0.backward(x.view(B, C, H * W)) + o1.backward(x.view(B, C, H * W)) + print((o0 - o1).abs().max()) + print((y.grad - y1.grad).abs().max()) + x.grad, x1.grad, y.grad, y1.grad = None, None, None, None + print("===============", flush=True) + + print("test cross scan one by one") + for (cs0, cs1) in [ + (cross_scan_1b1, lambda x: cross_scan_fn(x, one_by_one=True)), + # (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)), + ]: + o0 = cs0(y) + o1 = cs1(y1) + o0.backward(y.view(B, 4, C, H * W)) + o1.backward(y.view(B, 4, C, H * W)) + 
print((o0 - o1).abs().max()) + print((y.grad - y1.grad).abs().max()) + x.grad, x1.grad, y.grad, y1.grad = None, None, None, None + print("===============", flush=True) + + +if __name__ == "__main__": + CHECK.check_csm_triton() + + + + diff --git a/rscd/models/backbones/lib_mamba/csm_tritonk2.py b/rscd/models/backbones/lib_mamba/csm_tritonk2.py new file mode 100644 index 0000000000000000000000000000000000000000..19e530e22da0ab0877eafadd0dfd56ed06ec32d3 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/csm_tritonk2.py @@ -0,0 +1,899 @@ +import torch +import warnings +import os +os.environ["TRITON_INTERPRET"] = "1" + +WITH_TRITON = True +# WITH_TRITON = False +try: + import triton + import triton.language as tl +except: + WITH_TRITON = False + warnings.warn("Triton not installed, fall back to pytorch implements.") + +# to make sure cached_property can be loaded for triton +if WITH_TRITON: + try: + from functools import cached_property + except: + warnings.warn("if you are using py37, add this line to functools.py: " + "cached_property = lambda func: property(lru_cache()(func))") + +# torch implementation ======================================== +def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2): + if in_channel_first: + B, C, H, W = x.shape + if scans == 0: + y = x.new_empty((B, 4, C, H * W)) + y[:, 0, :, :] = x.flatten(2, 3) + y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3) + y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1]) + elif scans == 1: + y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1) + elif scans == 2: + y = x.view(B, 1, C, H * W) + y = torch.cat([y, y.flip(dims=[-1])], dim=1) + else: + B, H, W, C = x.shape + if scans == 0: + y = x.new_empty((B, H * W, 4, C)) + y[:, :, 0, :] = x.flatten(1, 2) + y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2) + y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1]) + elif scans == 1: + y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1) + elif scans == 2: + y = x.view(B, H * W, 1, C) + y = torch.cat([y, y.flip(dims=[1])], dim=2) + + if in_channel_first and (not out_channel_first): + y = y.permute(0, 3, 1, 2).contiguous() + elif (not in_channel_first) and out_channel_first: + y = y.permute(0, 2, 3, 1).contiguous() + + return y + + +def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2): + if out_channel_first: + B, K, D, H, W = y.shape + y = y.view(B, K, D, -1) + if scans == 0: + y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) + y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1) + elif scans == 1: + y = y.sum(1) + elif scans == 2: + y = y[:, 0] + y[:, 1].flip(dims=[-1]).view(B, 1, D, -1) + y = y.sum(1) + else: + B, H, W, K, D = y.shape + y = y.view(B, -1, K, D) + if scans == 0: + y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D) + y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D) + elif scans == 1: + y = y.sum(2) + elif scans == 2: + y = y[:, :, 0] + y[:, :, 1].flip(dims=[1]).view(B, -1, 1, D) + y = y.sum(2) + + if in_channel_first and (not out_channel_first): + y = y.permute(0, 2, 1).contiguous() + elif (not in_channel_first) and out_channel_first: + y = y.permute(0, 2, 1).contiguous() + + return y + + +def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2): + if in_channel_first: + B, _, C, H, W = x.shape + if scans == 0: + y = torch.stack([ + x[:, 0].flatten(2, 3), + x[:, 
1].transpose(dim0=2, dim1=3).flatten(2, 3), + torch.flip(x[:, 2].flatten(2, 3), dims=[-1]), + torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]), + ], dim=1) + elif scans == 1: + y = x.flatten(2, 3) + elif scans == 2: + y = torch.stack([ + x[:, 0].flatten(2, 3), + x[:, 1].flatten(2, 3), + torch.flip(x[:, 2].flatten(2, 3), dims=[-1]), + torch.flip(x[:, 3].flatten(2, 3), dims=[-1]), + ], dim=1) + else: + B, H, W, _, C = x.shape + if scans == 0: + y = torch.stack([ + x[:, :, :, 0].flatten(1, 2), + x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2), + torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]), + torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]), + ], dim=2) + elif scans == 1: + y = x.flatten(1, 2) + elif scans == 2: + y = torch.stack([ + x[:, 0].flatten(1, 2), + x[:, 1].flatten(1, 2), + torch.flip(x[:, 2].flatten(1, 2), dims=[-1]), + torch.flip(x[:, 3].flatten(1, 2), dims=[-1]), + ], dim=2) + + if in_channel_first and (not out_channel_first): + y = y.permute(0, 3, 1, 2).contiguous() + elif (not in_channel_first) and out_channel_first: + y = y.permute(0, 2, 3, 1).contiguous() + + return y + + +def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=2): + if out_channel_first: + B, K, D, H, W = y.shape + y = y.view(B, K, D, -1) + if scans == 0: + y = torch.stack([ + y[:, 0], + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), + torch.flip(y[:, 2], dims=[-1]), + torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]), + ], dim=1) + elif scans == 1: + y = y + elif scans == 2: + y = torch.stack([ + y[:, 0], + y[:, 1], + torch.flip(y[:, 2], dims=[-1]), + torch.flip(y[:, 3], dims=[-1]), + ], dim=1) + else: + B, H, W, _, D = y.shape + y = y.view(B, -1, 2, D) + if scans == 0: + y = torch.stack([ + y[:, :, 0], + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), + torch.flip(y[:, :, 2], dims=[1]), + torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]), + ], dim=2) + elif scans == 1: + y = y + elif scans == 2: + y = torch.stack([ + y[:, :, 0], + y[:, :, 1], + torch.flip(y[:, :, 2], dims=[1]), + torch.flip(y[:, :, 3], dims=[1]), + ], dim=2) + + if out_channel_first and (not in_channel_first): + y = y.permute(0, 3, 1, 2).contiguous() + elif (not out_channel_first) and in_channel_first: + y = y.permute(0, 2, 3, 1).contiguous() + + return y + +class CrossScan(torch.nn.Module): + def __init__(self, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2): + super(CrossScan, self).__init__() + self.in_channel_first = in_channel_first + self.out_channel_first = out_channel_first + self.one_by_one = one_by_one + self.scans = scans + + def forward(self, x: torch.Tensor): + if self.one_by_one: + B, K, C, H, W = x.shape + if not self.in_channel_first: + B, H, W, K, C = x.shape + else: + B, C, H, W = x.shape + if not self.in_channel_first: + B, H, W, C = x.shape + self.shape = (B, C, H, W) + + _fn = cross_scan1b1_fwd if self.one_by_one else cross_scan_fwd + y = _fn(x, self.in_channel_first, self.out_channel_first, self.scans) + + return y + + def backward(self, ys: torch.Tensor): + B, C, H, W = self.shape + + ys = ys.view(B, -1, C, H, W) if self.out_channel_first else ys.view(B, H, W, -1, C) + _fn = cross_merge1b1_fwd if self.one_by_one else cross_merge_fwd + y = _fn(ys, self.in_channel_first, self.out_channel_first, self.scans) + + if self.one_by_one: + y = y.view(B, 2, -1, H, W) if 
self.in_channel_first else y.view(B, H, W, 2, -1) + else: + y = y.view(B, -1, H, W) if self.in_channel_first else y.view(B, H, W, -1) + + return y + + +class CrossMerge(torch.nn.Module): + def __init__(self, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2): + super(CrossMerge, self).__init__() + self.in_channel_first = in_channel_first + self.out_channel_first = out_channel_first + self.one_by_one = one_by_one + self.scans = scans + + def forward(self, ys: torch.Tensor): + B, K, C, H, W = ys.shape + if not self.out_channel_first: + B, H, W, K, C = ys.shape + self.shape = (B, C, H, W) + + _fn = cross_merge1b1_fwd if self.one_by_one else cross_merge_fwd + y = _fn(ys, self.in_channel_first, self.out_channel_first, self.scans) + + return y + + def backward(self, x: torch.Tensor): + B, C, H, W = self.shape + + if not self.one_by_one: + if self.in_channel_first: + x = x.view(B, C, H, W) + else: + x = x.view(B, H, W, C) + else: + if self.in_channel_first: + x = x.view(B, 2, C, H, W) + else: + x = x.view(B, H, W, 2, C) + + _fn = cross_scan1b1_fwd if self.one_by_one else cross_scan_fwd + x = _fn(x, self.in_channel_first, self.out_channel_first, self.scans) + x = x.view(B, 2, C, H, W) if self.out_channel_first else x.view(B, H, W, 2, C) + + return x +class CrossScanF(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2): + # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 2, C) + # y: (B, 2, C, H * W) | (B, H * W, 2, C) + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + + if one_by_one: + B, K, C, H, W = x.shape + if not in_channel_first: + B, H, W, K, C = x.shape + else: + B, C, H, W = x.shape + if not in_channel_first: + B, H, W, C = x.shape + ctx.shape = (B, C, H, W) + + _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd + y = _fn(x, in_channel_first, out_channel_first, scans) + + return y + + @staticmethod + def backward(ctx, ys: torch.Tensor): + # out: (b, k, d, l) + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + + ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C) + _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd + y = _fn(ys, in_channel_first, out_channel_first, scans) + + if one_by_one: + y = y.view(B, 2, -1, H, W) if in_channel_first else y.view(B, H, W, 2, -1) + else: + y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1) + + return y, None, None, None, None + + +class CrossMergeF(torch.autograd.Function): + @staticmethod + def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2): + # x: (B, C, H, W) | (B, H, W, C) | (B, 2, C, H, W) | (B, H, W, 2, C) + # y: (B, 2, C, H * W) | (B, H * W, 4, C) + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + + B, K, C, H, W = ys.shape + if not out_channel_first: + B, H, W, K, C = ys.shape + ctx.shape = (B, C, H, W) + + _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd + y = _fn(ys, in_channel_first, out_channel_first, scans) + + return y + + @staticmethod + def backward(ctx, x: torch.Tensor): + # B, D, L = x.shape + # out: (b, k, d, h, w) + in_channel_first = ctx.in_channel_first + out_channel_first = 
ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + + if not one_by_one: + if in_channel_first: + x = x.view(B, C, H, W) + else: + x = x.view(B, H, W, C) + else: + if in_channel_first: + x = x.view(B, 2, C, H, W) + else: + x = x.view(B, H, W, 2, C) + + _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd + x = _fn(x, in_channel_first, out_channel_first, scans) + x = x.view(B, 2, C, H, W) if out_channel_first else x.view(B, H, W, 2, C) + + return x, None, None, None, None + + +# triton implements ======================================== + +@triton.jit +def triton_cross_scan_flex_k2( + x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + y, # (B, 4, C, H, W) | (B, H, W, 4, C) + x_layout: tl.constexpr, + y_layout: tl.constexpr, + operation: tl.constexpr, + onebyone: tl.constexpr, + scans: tl.constexpr, + BC: tl.constexpr, + BH: tl.constexpr, + BW: tl.constexpr, + DC: tl.constexpr, + DH: tl.constexpr, + DW: tl.constexpr, + NH: tl.constexpr, + NW: tl.constexpr, +): + # x_layout = 0 + # y_layout = 1 # 0 BCHW, 1 BHWC + # operation = 0 # 0 scan, 1 merge + # onebyone = 0 # 0 false, 1 true + # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional + + i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_w = (i_hw // NW), (i_hw % NW) + _mask_h = (i_h * BH + tl.arange(0, BH)) < DH + _mask_w = (i_w * BW + tl.arange(0, BW)) < DW + _mask_hw = _mask_h[:, None] & _mask_w[None, :] + _for_C = min(DC - i_c * BC, BC) + + HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] + # HWRoute1 = i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None] # trans + HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip + # HWRoute3 = (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip + + if scans == 1: + HWRoute2 = HWRoute0 + + + _tmp1 = DC * DH * DW + + y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) + if y_layout == 0: + p_y1 = y_ptr_base + HWRoute0 + # p_y2 = y_ptr_base + _tmp1 + HWRoute1 + p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2 + # p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3 + else: + p_y1 = y_ptr_base + HWRoute0 * 4 * DC + # p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC + p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC + # p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC + + if onebyone == 0: + x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x = x_ptr_base + HWRoute0 + else: + p_x = x_ptr_base + HWRoute0 * DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _x = tl.load(p_x + _idx_x, mask=_mask_hw) + tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) + # tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y3 + _idx_y, _x, mask=_mask_hw) + # tl.store(p_y4 + _idx_y, _x, mask=_mask_hw) + elif operation == 1: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) + # _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) + _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw) + # _y4 = 
tl.load(p_y4 + _idx_y, mask=_mask_hw) + # tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw) + tl.store(p_x + _idx_x, _y1 + _y3, mask=_mask_hw) + + + else: + x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x1 = x_ptr_base + HWRoute0 + p_x2 = p_x1 + _tmp1 + p_x3 = p_x2 + _tmp1 + p_x4 = p_x3 + _tmp1 + else: + p_x1 = x_ptr_base + HWRoute0 * 4 * DC + p_x2 = p_x1 + DC + p_x3 = p_x2 + DC + p_x4 = p_x3 + DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw) + # tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw) + # tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw) + else: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw) + # tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw) + tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw) + # tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw) + +@triton.jit +def triton_cross_scan_flex_k2( + x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + y, # (B, 4, C, H, W) | (B, H, W, 4, C) + x_layout: tl.constexpr, + y_layout: tl.constexpr, + operation: tl.constexpr, + onebyone: tl.constexpr, + scans: tl.constexpr, + BC: tl.constexpr, + BH: tl.constexpr, + BW: tl.constexpr, + DC: tl.constexpr, + DH: tl.constexpr, + DW: tl.constexpr, + NH: tl.constexpr, + NW: tl.constexpr, +): + i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_w = (i_hw // NW), (i_hw % NW) + _mask_h = (i_h * BH + tl.arange(0, BH)) < DH + _mask_w = (i_w * BW + tl.arange(0, BW)) < DW + _mask_hw = _mask_h[:, None] & _mask_w[None, :] + _for_C = min(DC - i_c * BC, BC) + + HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] + HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip + + if scans == 1: + HWRoute2 = HWRoute0 + + _tmp1 = DC * DH * DW + + y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) + if y_layout == 0: + p_y1 = y_ptr_base + HWRoute0 + p_y2 = y_ptr_base + 2 * _tmp1 + HWRoute2 + else: + p_y1 = y_ptr_base + HWRoute0 * 4 * DC + p_y2 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC + + if onebyone == 0: + x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x = x_ptr_base + HWRoute0 + else: + p_x = x_ptr_base + HWRoute0 * DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _x = tl.load(p_x + _idx_x, mask=_mask_hw) + tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) + elif operation == 1: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) + _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) + tl.store(p_x + _idx_x, _y1 + _y2, mask=_mask_hw) + + else: + 
x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x1 = x_ptr_base + HWRoute0 + p_x2 = p_x1 + 2 * _tmp1 + else: + p_x1 = x_ptr_base + HWRoute0 * 4 * DC + p_x2 = p_x1 + 2 * DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw) + tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw) + else: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw) + tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw) + +@triton.jit +def triton_cross_scan_flex_k2( + x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) + y, # (B, 4, C, H, W) | (B, H, W, 4, C) + x_layout: tl.constexpr, + y_layout: tl.constexpr, + operation: tl.constexpr, + onebyone: tl.constexpr, + scans: tl.constexpr, + BC: tl.constexpr, + BH: tl.constexpr, + BW: tl.constexpr, + DC: tl.constexpr, + DH: tl.constexpr, + DW: tl.constexpr, + NH: tl.constexpr, + NW: tl.constexpr, +): + i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_w = (i_hw // NW), (i_hw % NW) + _mask_h = (i_h * BH + tl.arange(0, BH)) < DH + _mask_w = (i_w * BW + tl.arange(0, BW)) < DW + _mask_hw = _mask_h[:, None] & _mask_w[None, :] + _for_C = min(DC - i_c * BC, BC) + + HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] + HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip + + if scans == 1: + HWRoute2 = HWRoute0 + + _tmp1 = DC * DH * DW + + y_ptr_base = y + i_b * 2 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) + if y_layout == 0: + p_y1 = y_ptr_base + HWRoute0 + p_y2 = y_ptr_base + 1 * _tmp1 + HWRoute2 + else: + p_y1 = y_ptr_base + HWRoute0 * 4 * DC + p_y2 = y_ptr_base + 1 * DC + HWRoute2 * 4 * DC + + if onebyone == 0: + x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x = x_ptr_base + HWRoute0 + else: + p_x = x_ptr_base + HWRoute0 * DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _x = tl.load(p_x + _idx_x, mask=_mask_hw) + tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) + tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) + elif operation == 1: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) + _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) + tl.store(p_x + _idx_x, _y1 + _y2, mask=_mask_hw) + + else: + x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) + if x_layout == 0: + p_x1 = x_ptr_base + HWRoute0 + p_x2 = p_x1 + 2 * _tmp1 + else: + p_x1 = x_ptr_base + HWRoute0 * 4 * DC + p_x2 = p_x1 + 2 * DC + + if operation == 0: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _x1 = tl.load(p_x1 + _idx_x, mask=_mask_hw) + _x2 = tl.load(p_x2 + _idx_x, mask=_mask_hw) + tl.store(p_y1 + _idx_y, _x1, mask=_mask_hw) + tl.store(p_y2 
+ _idx_y, _x2, mask=_mask_hw) + else: + for idxc in range(_for_C): + _idx_x = idxc * DH * DW if x_layout == 0 else idxc + _idx_y = idxc * DH * DW if y_layout == 0 else idxc + _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) + _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) + tl.store(p_x1 + _idx_x, _y1, mask=_mask_hw) + tl.store(p_x2 + _idx_x, _y2, mask=_mask_hw) + +class CrossScanTritonFk2(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2): + if one_by_one: + if in_channel_first: + B, _, C, H, W = x.shape + else: + B, H, W, _, C = x.shape + else: + if in_channel_first: + B, C, H, W = x.shape + else: + B, H, W, C = x.shape + B, C, H, W = int(B), int(C), int(H), int(W) + BC, BH, BW = 1, 32, 32 + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + ctx.shape = (B, C, H, W) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + + y = x.new_empty((B, 2, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 2, C)) + triton_cross_scan_flex_k2[(NH * NW, NC, B)]( + x.contiguous(), y, + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return y + + @staticmethod + def backward(ctx, y: torch.Tensor): + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + if one_by_one: + x = y.new_empty((B, 2, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 2, C)) + else: + x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C)) + + triton_cross_scan_flex_k2[(NH * NW, NC, B)]( + x, y.contiguous(), + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return x, None, None, None, None + + +class CrossMergeTritonFk2(torch.autograd.Function): + @staticmethod + def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2): + if out_channel_first: + B, _, C, H, W = y.shape + else: + B, H, W, _, C = y.shape + B, C, H, W = int(B), int(C), int(H), int(W) + BC, BH, BW = 1, 32, 32 + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + ctx.in_channel_first = in_channel_first + ctx.out_channel_first = out_channel_first + ctx.one_by_one = one_by_one + ctx.scans = scans + ctx.shape = (B, C, H, W) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + if one_by_one: + x = y.new_empty((B, 2, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 2, C)) + else: + x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C)) + triton_cross_scan_flex_k2[(NH * NW, NC, B)]( + x, y.contiguous(), + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return x + + @staticmethod + def backward(ctx, x: torch.Tensor): + in_channel_first = ctx.in_channel_first + out_channel_first = ctx.out_channel_first + one_by_one = ctx.one_by_one + scans = ctx.scans + B, C, H, W = ctx.shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + y = x.new_empty((B, 2, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 2, C)) + triton_cross_scan_flex_k2[(NH * NW, NC, B)]( + 
x.contiguous(), y, + (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, + BC, BH, BW, C, H, W, NH, NW + ) + return y, None, None, None, None, None + + +# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) +def cross_scan_fn_k2(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False): + # x: (B, C, H, W) | (B, H, W, C) | (B, 2, C, H, W) | (B, H, W, 2, C) + # y: (B, 2, C, L) | (B, L, 2, C) + # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; + CSF = CrossScanTritonFk2 if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF + return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) + +# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) +def cross_merge_fn_k2(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False): + # y: (B, 2, C, L) | (B, L, 2, C) + # x: (B, C, H * W) | (B, H * W, C) | (B, 2, C, H * W) | (B, H * W, 2, C) + # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; + CMF = CrossMergeTritonFk2 if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF + return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) + +def cross_scan_fn_k2_torch(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False): + cross_scan = CrossScan(in_channel_first, out_channel_first, one_by_one, scans) + return cross_scan(x) + +def cross_merge_fn_k2_torch(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=2, force_torch=False): + cross_merge = CrossMerge(in_channel_first, out_channel_first, one_by_one, scans) + return cross_merge(y) + +# checks ================================================================= + +class CHECK: + def check_csm_triton(): + B, C, H, W = 2, 192, 56, 57 + dtype=torch.float16 + dtype=torch.float32 + x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True) + y = torch.randn((B, 2, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True) + x1 = x.clone().detach().requires_grad_(True) + y1 = y.clone().detach().requires_grad_(True) + + def cross_scan(x: torch.Tensor): + B, C, H, W = x.shape + L = H * W + xs = torch.stack([ + x.view(B, C, L), + torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), + torch.flip(x.contiguous().view(B, C, L), dims=[-1]), + torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]), + ], dim=1).view(B, 4, C, L) + return xs + + def cross_merge(out_y: torch.Tensor): + B, K, D, H, W = out_y.shape + L = H * W + out_y = out_y.view(B, K, D, L) + inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) + wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y + return y + + def cross_scan_1b1(x: torch.Tensor): + B, K, C, H, W = x.shape + L = H * W + xs = torch.stack([ + x[:, 0].view(B, C, L), + torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L), + torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]), + torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]), + ], dim=1).view(B, 2, C, L) + return xs + + def unidi_scan(x): + B, C, H, W = x.shape + x = x.view(B, 1, C, H * 
W).repeat(1, 4, 1, 1) + return x + + def unidi_merge(ys): + B, K, C, H, W = ys.shape + return ys.view(B, 4, -1, H * W).sum(1) + + def bidi_scan(x): + B, C, H, W = x.shape + x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1) + x = torch.cat([x, x.flip(dims=[-1])], dim=1) + return x + + def bidi_merge(ys): + B, K, D, H, W = ys.shape + ys = ys.view(B, K, D, -1) + ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) + return ys.contiguous().sum(1) + + if True: + # res0 = triton.testing.do_bench(lambda :cross_scan(x)) + res1 = triton.testing.do_bench(lambda :cross_scan_fn_k2(x, True, True, False)) + # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x)) + # res3 = triton.testing.do_bench(lambda :cross_merge(y)) + res4 = triton.testing.do_bench(lambda :cross_merge_fn_k2(y, True, True, False)) + # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y)) + # print(res0, res1, res2, res3, res4, res5) + print(res0, res1, res3, res4) + res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward()) + res1 = triton.testing.do_bench(lambda :cross_scan_fn_k2(x, True, True, False).sum().backward()) + # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x).sum().backward()) + res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward()) + res4 = triton.testing.do_bench(lambda :cross_merge_fn_k2(y, True, True, False).sum().backward()) + # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y).sum().backward()) + # print(res0, res1, res2, res3, res4, res5) + print(res0, res1, res3, res4) + + print("test cross scan") + for (cs0, cm0, cs1, cm1) in [ + # channel_first -> channel_first + (cross_scan, cross_merge, cross_scan_fn_k2, cross_merge_fn_k2), + (unidi_scan, unidi_merge, lambda x: cross_scan_fn_k2(x, scans=1), lambda x: cross_merge_fn_k2(x, scans=1)), + (bidi_scan, bidi_merge, lambda x: cross_scan_fn_k2(x, scans=2), lambda x: cross_merge_fn_k2(x, scans=2)), + + # flex: BLC->BCL; BCL->BLC; BLC->BLC; + (cross_scan, cross_merge, lambda x: cross_scan_fn_k2(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn_k2(x, in_channel_first=False).permute(0, 2, 1)), + (cross_scan, cross_merge, lambda x: cross_scan_fn_k2(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn_k2(x.permute(0, 3, 4, 1, 2), out_channel_first=False)), + (cross_scan, cross_merge, lambda x: cross_scan_fn_k2(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn_k2(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)), + + # previous + # (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)), + # (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)), + # (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)), + ]: + x.grad, x1.grad, y.grad, y1.grad = None, None, None, None + o0 = cs0(x) + o1 = cs1(x1) + o0.backward(y.view(B, 2, C, H * W)) + o1.backward(y.view(B, 2, C, H * W)) + print((o0 - o1).abs().max()) + print((x.grad - x1.grad).abs().max()) + o0 = cm0(y) + o1 = cm1(y1) + o0.backward(x.view(B, C, H * W)) + o1.backward(x.view(B, C, H * W)) + print((o0 - o1).abs().max()) + print((y.grad - y1.grad).abs().max()) + x.grad, x1.grad, y.grad, y1.grad = None, None, None, None + print("===============", flush=True) + + print("test cross scan one by one") + for (cs0, cs1) in [ + (cross_scan_1b1, lambda x: cross_scan_fn_k2(x, 
one_by_one=True)), + # (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)), + ]: + o0 = cs0(y) + o1 = cs1(y1) + o0.backward(y.view(B, 2, C, H * W)) + o1.backward(y.view(B, 2, C, H * W)) + print((o0 - o1).abs().max()) + print((y.grad - y1.grad).abs().max()) + x.grad, x1.grad, y.grad, y1.grad = None, None, None, None + print("===============", flush=True) + + +if __name__ == "__main__": + CHECK.check_csm_triton() + + + + diff --git a/rscd/models/backbones/lib_mamba/csms6s.py b/rscd/models/backbones/lib_mamba/csms6s.py new file mode 100644 index 0000000000000000000000000000000000000000..67945f2469836b34b0d2f02e63e7ad3cb9870d29 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/csms6s.py @@ -0,0 +1,266 @@ +import time +import torch +import warnings + + +WITH_SELECTIVESCAN_OFLEX = True +WITH_SELECTIVESCAN_CORE = False +WITH_SELECTIVESCAN_MAMBA = True +try: + import selective_scan_cuda_oflex +except ImportError: + WITH_SELECTIVESCAN_OFLEX = False + warnings.warn("Can not import selective_scan_cuda_oflex. This affects speed.") + print("Can not import selective_scan_cuda_oflex. This affects speed.", flush=True) +try: + import selective_scan_cuda_core +except ImportError: + WITH_SELECTIVESCAN_CORE = False +try: + import selective_scan_cuda +except ImportError: + WITH_SELECTIVESCAN_MAMBA = False + + +def selective_scan_torch( + u: torch.Tensor, # (B, K * C, L) + delta: torch.Tensor, # (B, K * C, L) + A: torch.Tensor, # (K * C, N) + B: torch.Tensor, # (B, K, N, L) + C: torch.Tensor, # (B, K, N, L) + D: torch.Tensor = None, # (K * C) + delta_bias: torch.Tensor = None, # (K * C) + delta_softplus=True, + oflex=True, + *args, + **kwargs +): + dtype_in = u.dtype + Batch, K, N, L = B.shape + KCdim = u.shape[1] + Cdim = int(KCdim / K) + assert u.shape == (Batch, KCdim, L) + assert delta.shape == (Batch, KCdim, L) + assert A.shape == (KCdim, N) + assert C.shape == B.shape + + if delta_bias is not None: + delta = delta + delta_bias[..., None] + if delta_softplus: + delta = torch.nn.functional.softplus(delta) + + u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float() + B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L) + C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L) + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + + if True: + x = A.new_zeros((Batch, KCdim, N)) + ys = [] + for i in range(L): + x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :] + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + ys.append(y) + y = torch.stack(ys, dim=2) # (B, C, L) + + out = y if D is None else y + u * D.unsqueeze(-1) + return out if oflex else out.to(dtype=dtype_in) + + +class SelectiveScanCuda(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None): + ctx.delta_softplus = delta_softplus + backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend + backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend + backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend + ctx.backend = backend + if backend == "oflex": + out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex) + elif backend == "core": + out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1) + elif backend == "mamba": + 
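+            # Assumption (not verified in this repo): the stock selective_scan_cuda binding
+            # takes one extra positional argument between D and delta_bias (z in the upstream
+            # mamba interface); no gate is fused in this wrapper, so it is passed as None below.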
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dout, *args): + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + backend = ctx.backend + if dout.stride(-1) != 1: + dout = dout.contiguous() + if backend == "oflex": + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 + ) + elif backend == "core": + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 + ) + elif backend == "mamba": + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, + False + ) + return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None + + +def selective_scan_fn( + u: torch.Tensor, # (B, K * C, L) + delta: torch.Tensor, # (B, K * C, L) + A: torch.Tensor, # (K * C, N) + B: torch.Tensor, # (B, K, N, L) + C: torch.Tensor, # (B, K, N, L) + D: torch.Tensor = None, # (K * C) + delta_bias: torch.Tensor = None, # (K * C) + delta_softplus=True, + oflex=True, + backend=None, +): + WITH_CUDA = (WITH_SELECTIVESCAN_OFLEX or WITH_SELECTIVESCAN_CORE or WITH_SELECTIVESCAN_MAMBA) + fn = selective_scan_torch if backend == "torch" or (not WITH_CUDA) else SelectiveScanCuda.apply + return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend) + + +# fvcore flops ======================================= +def print_jit_input_names(inputs): + print("input params: ", end=" ", flush=True) + try: + for i in range(10): + print(inputs[i].debugName(), end=" ", flush=True) + except Exception as e: + pass + print("", flush=True) + +def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False): + """ + u: r(B D L) + delta: r(B D L) + A: r(D N) + B: r(B N L) + C: r(B N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + ignores: + [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] + """ + assert not with_complex + # https://github.com/state-spaces/mamba/issues/110 + flops = 9 * B * L * D * N + if with_D: + flops += B * D * L + if with_Z: + flops += B * D * L + return flops + +# this is only for selective_scan_ref... 
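+# (flops_selective_scan_fn above is the closed-form count, 9*B*L*D*N plus the optional
+#  D/Z terms; the reference version below instead derives einsum FLOPs via np.einsum_path.)
+# Minimal sketch of a call to either counter, with made-up sizes for illustration:
+#   flops = flops_selective_scan_fn(B=1, L=64 * 64, D=96, N=16, with_D=True)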
+def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): + """ + u: r(B D L) + delta: r(B D L) + A: r(D N) + B: r(B N L) + C: r(B N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + ignores: + [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] + """ + import numpy as np + + # fvcore.nn.jit_handles + def get_flops_einsum(input_shapes, equation): + np_arrs = [np.zeros(s) for s in input_shapes] + optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] + for line in optim.split("\n"): + if "optimized flop" in line.lower(): + # divided by 2 because we count MAC (multiply-add counted as one flop) + flop = float(np.floor(float(line.split(":")[-1]) / 2)) + return flop + + + assert not with_complex + + flops = 0 # below code flops = 0 + + flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") + if with_Group: + flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") + else: + flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") + + in_for_flops = B * D * N + if with_Group: + in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") + else: + in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") + flops += L * in_for_flops + if with_D: + flops += B * D * L + if with_Z: + flops += B * D * L + return flops + +def selective_scan_flop_jit(inputs, outputs, backend="prefixsum", verbose=True): + if verbose: + print_jit_input_names(inputs) + flops_fn = flops_selective_scan_ref if backend == "naive" else flops_selective_scan_fn + B, D, L = inputs[0].type().sizes() + N = inputs[2].type().sizes()[1] + flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False) + return flops + + +if __name__ == "__main__": + def params(B, K, C, N, L, device = torch.device("cuda"), itype = torch.float): + As = (-0.5 * torch.rand(K * C, N, device=device, dtype=torch.float32)).requires_grad_() + Bs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_() + Cs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_() + Ds = torch.randn((K * C), device=device, dtype=torch.float32).requires_grad_() + u = torch.randn((B, K * C, L), device=device, dtype=itype).requires_grad_() + delta = (0.5 * torch.rand((B, K * C, L), device=device, dtype=itype)).requires_grad_() + delta_bias = (0.5 * torch.rand((K * C), device=device, dtype=torch.float32)).requires_grad_() + return u, delta, As, Bs, Cs, Ds, delta_bias + + def bench(func, xs, Warmup=30, NTimes=20): + import time + torch.cuda.synchronize() + for r in range(Warmup): + for x in xs: + func(x) + torch.cuda.synchronize() + tim0 = time.time() + for r in range(NTimes): + for x in xs: + func(x) + torch.cuda.synchronize() + return (time.time() - tim0) / NTimes + + def check(): + u, delta, As, Bs, Cs, Ds, delta_bias = params(1, 4, 16, 8, 512, itype=torch.float16) + u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1 = [x.clone().detach().requires_grad_() for x in [u, delta, As, Bs, Cs, Ds, delta_bias]] + + # out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="torch") + out = selective_scan_fn(u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1, True, backend="oflex") + out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="mamba") + print((out_ref - out).abs().max()) + out.sum().backward() + out_ref.sum().backward() + for x, y in zip([u, As, Bs, Cs, Ds, delta, delta_bias], [u1, As1, Bs1, Cs1, Ds1, delta1, delta_bias1]): + 
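+            # per-tensor gradient gap between the mamba reference and the oflex build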
print((x.grad - y.grad).abs().max()) + + u, delta, As, Bs, Cs, Ds, delta_bias = params(128, 4, 96, 8, 56 * 56) + print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="oflex"), [(u, delta, As, Bs, Cs, Ds, delta_bias),])) + print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="mamba"), [(u, delta, As, Bs, Cs, Ds, delta_bias),])) + print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="torch"), [(u, delta, As, Bs, Cs, Ds, delta_bias),])) + + check() + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/README.md b/rscd/models/backbones/lib_mamba/kernels/selective_scan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..89d104660abd27478cb28f59c4b66176ee8b0b54 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/README.md @@ -0,0 +1,97 @@ +# mamba-mini +An efficient implementation of selective scan in one file, works with both cpu and gpu, with corresponding mathematical derivation. It is probably the code which is the most close to selective_scan_cuda in mamba. + +### mathematical derivation +![image](../assets/derivation.png) + +### code +```python +import torch +def selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64): + """ + # B: batch_size, G: groups, D: dim, N: state dim, L: seqlen + us: B, G * D, L + dts: B, G * D, L + As: G * D, N + Bs: B, G, N, L + Cs: B, G, N, L + Ds: G * D + delta_bias: G * D + # chunksize can be any as you like. But as the chunksize raises, hs may get None, as exp(sum(delta) A) is really small + """ + def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix): + """ + partial(h) / partial(t) = Ah + Bu; y = Ch + Du; + => partial(h*exp(-At)) / partial(t) = Bu*exp(-At); + => h_t = h_0 + sum_{0}_{t}_{Bu*exp(A(t-v)) dv}; + => h_b = exp(A(dt_a + ... + dt_{b-1})) * (h_a + sum_{a}_{b-1}_{Bu*exp(-A(dt_a + ... 
+ dt_i)) dt_i}); + y_i = C_i*h_i + D*u_i + """ + """ + us, dts: (L, B, G, D) # L is chunk_size + As: (G, D, N) + Bs, Cs: (L, B, G, N) + Ds: (G, D) + hprefix: (B, G, D, N) + """ + ts = dts.cumsum(dim=0) + Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp() + scale = Ats[-1].detach() + rAts = Ats / scale + duts = dts * us + dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs) + hs_tmp = rAts * (dtBus / rAts).cumsum(dim=0) + hs = hs_tmp + Ats * hprefix.unsqueeze(0) + ys = torch.einsum("lbgn,lbgdn->lbgd", Cs, hs) + return ys, hs + + inp_dtype = us.dtype + has_D = Ds is not None + + dts = dts.float() + if delta_bias is not None: + dts = dts + delta_bias.view(1, -1, 1).float() + if delta_softplus: + dts = torch.nn.functional.softplus(dts) + + if len(Bs.shape) == 3: + Bs = Bs.unsqueeze(1) + if len(Cs.shape) == 3: + Cs = Cs.unsqueeze(1) + B, G, N, L = Bs.shape + us = us.view(B, G, -1, L).permute(3, 0, 1, 2).float() + dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).float() + As = As.view(G, -1, N).float() + Bs = Bs.permute(3, 0, 1, 2).float() + Cs = Cs.permute(3, 0, 1, 2).float() + Ds = Ds.view(G, -1).float() if has_D else None + D = As.shape[1] + + oys = [] + # ohs = [] + hprefix = us.new_zeros((B, G, D, N), dtype=torch.float) + for i in range(0, L - 1, chunksize): + ys, hs = selective_scan_chunk( + us[i:i + chunksize], dts[i:i + chunksize], + As, Bs[i:i + chunksize], Cs[i:i + chunksize], hprefix, + ) + oys.append(ys) + # ohs.append(hs) + hprefix = hs[-1] + + oys = torch.cat(oys, dim=0) + # ohs = torch.cat(ohs, dim=0) + if has_D: + oys = oys + Ds * us + oys = oys.permute(1, 2, 3, 0).view(B, -1, L) + oys = oys.to(inp_dtype) + # hprefix = hprefix.to(inp_dtype) + + return oys if not return_last_state else (oys, hprefix.view(B, G * D, N)) + +``` + +### to test +```bash +pytest test_selective_scan.py +``` diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..37f1b1782f84ba2fd42216cfe102c8f293f01d2e --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/lib.linux-x86_64-3.8/selective_scan_cuda_oflex.cpython-38-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3fcf524eeaf71e641653c1aff1f8fac591a1e6916300d322830bd02476873ab1 +size 34969816 diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps new file mode 100644 index 0000000000000000000000000000000000000000..1a1b0b75eaf3dca3f6c34ced86b429113c5b5448 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_deps @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e61c23bbef4f0b8f414187d9149e6e0818ce400c3351405d494b774b988bf6d +size 501136 diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_log b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_log new file mode 100644 index 0000000000000000000000000000000000000000..c2f56a914b854bfa27b1e70a5bb5ef41a3f68221 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/.ninja_log @@ -0,0 
+1,4 @@ +# ninja log v5 +7 17272 1748140810026258300 /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o ab3bac6bd7b8268f +8 23832 1748140816810431800 /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o 7f9a77b388057fc6 +7 57431 1748140850419474900 /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o 3cffffbdd6b9fec1 diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/build.ninja b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/build.ninja new file mode 100644 index 0000000000000000000000000000000000000000..11cfa7eb64912534eda862f4f3eed5fdeb32f641 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/build.ninja @@ -0,0 +1,35 @@ +ninja_required_version = 1.3 +cxx = c++ +nvcc = /usr/local/cuda-11.8/bin/nvcc + +cflags = -pthread -B /root/anaconda3/envs/rscd/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.8/include -I/root/anaconda3/envs/rscd/include/python3.8 -c +post_cflags = -O3 -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=selective_scan_cuda_oflex -D_GLIBCXX_USE_CXX11_ABI=0 +cuda_cflags = -I/mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/rscd/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.8/include -I/root/anaconda3/envs/rscd/include/python3.8 -c +cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT162_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math --ptxas-options=-v -lineinfo -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90 --threads 4 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=selective_scan_cuda_oflex -D_GLIBCXX_USE_CXX11_ABI=0 +cuda_dlink_post_cflags = +ldflags = + +rule compile + command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags + depfile = 
$out.d + deps = gcc + +rule cuda_compile + depfile = $out.d + deps = gcc + command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags + + + + + +build /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o: cuda_compile /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu +build /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o: cuda_compile /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu +build /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o: compile /mnt/d/WORK/rschange-main/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp + + + + + + + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o new file mode 100644 index 0000000000000000000000000000000000000000..9c41f214460f3149fb96089d70668efa55177f22 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_bwd.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a18749c332876fbeddf0356bbb0d4c979fffd89df7f4a27796e2dd39b523f2e +size 12294744 diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o new file mode 100644 index 0000000000000000000000000000000000000000..bc494f76f0f7a62c703e888570ba7f8713a86fd3 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_core_fwd.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:642208da032fd86584c184e13e9ed9d18a9c6e85d925770f7ce034e0d22774a9 +size 13211880 diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o new file mode 100644 index 0000000000000000000000000000000000000000..68b4562d494430f57e8758bb44c148a662c4b200 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/build/temp.linux-x86_64-3.8/csrc/selective_scan/cusoflex/selective_scan_oflex.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0867e500c6352b2c0e1938ea0c8e6825aafbabc5699ec41d25a7793c56ed5d1e +size 14839600 diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cub_extra.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cub_extra.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9f56704fea0e7b98bed6f333526641de27484d5a --- 
/dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cub_extra.cuh @@ -0,0 +1,50 @@ +// WarpMask is copied from /usr/local/cuda-12.1/include/cub/util_ptx.cuh +// PowerOfTwo is copied from /usr/local/cuda-12.1/include/cub/util_type.cuh + +#pragma once + +#include +#include +#include +#include + +/** + * \brief Statically determine if N is a power-of-two + */ + template + struct PowerOfTwo + { + enum { VALUE = ((N & (N - 1)) == 0) }; + }; + + +/** + * @brief Returns the warp mask for a warp of @p LOGICAL_WARP_THREADS threads + * + * @par + * If the number of threads assigned to the virtual warp is not a power of two, + * it's assumed that only one virtual warp exists. + * + * @tparam LOGICAL_WARP_THREADS [optional] The number of threads per + * "logical" warp (may be less than the number of + * hardware warp threads). + * @param warp_id Id of virtual warp within architectural warp + */ + template + __host__ __device__ __forceinline__ + unsigned int WarpMask(unsigned int warp_id) + { + constexpr bool is_pow_of_two = PowerOfTwo::VALUE; + constexpr bool is_arch_warp = LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0); + + unsigned int member_mask = 0xFFFFFFFFu >> + (CUB_WARP_THREADS(0) - LOGICAL_WARP_THREADS); + + if (is_pow_of_two && !is_arch_warp) + { + member_mask <<= warp_id * LOGICAL_WARP_THREADS; + } + + return member_mask; + } + \ No newline at end of file diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan.cpp b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d8a8e6c69343ab3ac9e6b740fffb425224fc424 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan.cpp @@ -0,0 +1,354 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include "selective_scan.h" +#define MAX_DSTATE 256 + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +using weight_t = float; + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool delta_softplus) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + params.B_dstate_stride = B.stride(2); + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + params.C_dstate_stride = C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +void set_ssm_params_bwd(SSMParamsBwd ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + const at::Tensor dout, + const at::Tensor du, + const at::Tensor ddelta, + const at::Tensor dA, + const at::Tensor dB, + const at::Tensor dC, + void* dD_ptr, + void* ddelta_bias_ptr, + bool delta_softplus) { + // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z + set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, + u, delta, A, B, C, dout, + D_ptr, delta_bias_ptr, x_ptr, delta_softplus); + + // Set the pointers and strides. + params.dout_ptr = dout.data_ptr(); + params.du_ptr = du.data_ptr(); + params.dA_ptr = dA.data_ptr(); + params.dB_ptr = dB.data_ptr(); + params.dC_ptr = dC.data_ptr(); + params.dD_ptr = dD_ptr; + params.ddelta_ptr = ddelta.data_ptr(); + params.ddelta_bias_ptr = ddelta_bias_ptr; + // All stride are in elements, not bytes. 
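+    // Gradient tensors mirror the forward layouts: dout/du/ddelta are (batch, dim, seqlen),
+    // dA is (dim, dstate), and dB/dC share B/C's (batch, n_groups, dstate, seqlen) shape
+    // (allocated as float32 in selective_scan_bwd and cast back on return).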
+ params.dout_batch_stride = dout.stride(0); + params.dout_d_stride = dout.stride(1); + params.dA_d_stride = dA.stride(0); + params.dA_dstate_stride = dA.stride(1); + params.dB_batch_stride = dB.stride(0); + params.dB_group_stride = dB.stride(1); + params.dB_dstate_stride = dB.stride(2); + params.dC_batch_stride = dC.stride(0); + params.dC_group_stride = dC.stride(1); + params.dC_dstate_stride = dC.stride(2); + params.du_batch_stride = du.stride(0); + params.du_d_stride = du.stride(1); + params.ddelta_batch_stride = ddelta.stride(0); + params.ddelta_d_stride = ddelta.stride(1); + +} + +std::vector +selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &delta_bias_, + bool delta_softplus, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == input_type); + TORCH_CHECK(C.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = B.size(1); + + TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups"); + TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel + at::Tensor out = torch::empty_like(delta); + at::Tensor x; + x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, + u, delta, A, B, C, out, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? 
delta_bias_.value().data_ptr() : nullptr, + x.data_ptr(), + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda<1, input_t, weight_t>(params, stream); + }); + std::vector result = {out, x}; + return result; +} + +std::vector +selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &delta_bias_, + const at::Tensor &dout, + const c10::optional &x_, + bool delta_softplus, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == input_type); + TORCH_CHECK(C.scalar_type() == input_type); + TORCH_CHECK(dout.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + TORCH_CHECK(dout.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = B.size(1); + + TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups"); + TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + CHECK_SHAPE(dout, batch_size, dim, seqlen); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor out; + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } + if (x_.has_value()) { + auto x = x_.value(); + TORCH_CHECK(x.scalar_type() == weight_type); + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.is_contiguous()); + CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); + } + + at::Tensor du = torch::empty_like(u); + at::Tensor ddelta = torch::empty_like(delta); + at::Tensor dA = torch::zeros_like(A); + at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32)); + at::Tensor dC = torch::zeros_like(C, 
C.options().dtype(torch::kFloat32)); + at::Tensor dD; + if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } + at::Tensor ddelta_bias; + if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } + + SSMParamsBwd params; + set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, + u, delta, A, B, C, out, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x_.has_value() ? x_.value().data_ptr() : nullptr, + dout, du, ddelta, dA, dB, dC, + D_.has_value() ? dD.data_ptr() : nullptr, + delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { + selective_scan_bwd_cuda<1, input_t, weight_t>(params, stream); + }); + std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; + return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &selective_scan_fwd, "Selective scan forward"); + m.def("bwd", &selective_scan_bwd, "Selective scan backward"); +} diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_bwd_kernel.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_bwd_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..35c2c0c504b823adc1ae6a862a7c8a070cf7830c --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_bwd_kernel.cuh @@ -0,0 +1,306 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int MaxDState = MAX_DSTATE; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. 
+ static constexpr int kMinBlocks = kNThreads == 128 && 3; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + 2 * sizeof(typename BlockLoadWeightT::TempStorage), + 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
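+    // One dynamic shared-memory allocation backs everything below: the cub load/store temp
+    // storage aliases the start of the buffer, the exchange/reduce/scan structs sit at fixed
+    // offsets, and the delta_a / running-postfix / dA scratch arrays live past kSmemSize;
+    // the __syncthreads() calls between phases make the reuse safe.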
+ extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; + float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id]; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; + float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; + float dD_val = 0; + float ddelta_bias_val = 0; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize; + Cvar += (params.n_chunks - 1) * kChunkSize; + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; + input_t dout_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + u -= kChunkSize; + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + dout -= kChunkSize; + + float dout_vals[kNItems], delta_vals[kNItems]; + float du_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[i] = float(dout_vals_load[i]); + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if constexpr (kDeltaSoftplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } + + float ddelta_vals[kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + constexpr float kLog2e = M_LOG2E; + weight_t A_val = A[state_idx * params.A_dstate_stride]; + weight_t A_scaled = A_val * kLog2e; + weight_t B_vals[kNItems], C_vals[kNItems]; + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize)); + auto &smem_load_weight_C = smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize)); + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + thread_data[i] = make_float2(delta_a_exp, delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState: threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[i] * C_vals[i]; + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = dx * B_vals[i]; + du_vals[i] += ddelta_u * delta_vals[i]; + const float a = thread_data[i].y - (delta_vals[i] * float(u_vals[i]) * B_vals[i]); + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; + dA_val += dx * delta_vals[i] * a; + dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); + dC_vals[i] = dout_vals[i] * thread_data[i].y; + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + auto &smem_exchange_C = smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } + + if constexpr (kDeltaSoftplus) { + input_t delta_vals_load[kNItems]; + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + delta -= kChunkSize; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[i]) + delta_bias; + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[i] = delta_val <= 20.f + ? 
ddelta_vals[i] / (1.f + delta_val_neg_exp) + : ddelta_vals[i]; + } + } + + __syncthreads(); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * params.ddelta_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); + Bvar -= kChunkSize; + Cvar -= kChunkSize; + } + + if (params.dD_ptr != nullptr) { + __syncthreads(); + dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); + if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); + if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } + } + __syncthreads(); + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_bwd.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_bwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..aa105781b26803c5529f6ce16a2ca96bdec9a9ee --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_bwd.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_fwd.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_fwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..75cdbf7cb1a21fea73667981885ab1bb85c35651 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_fwd.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_fwd_kernel.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_fwd_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9a36ea4fdc4ab2b3aa027b05db8a793d28b14927 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cus/selective_scan_fwd_kernel.cuh @@ -0,0 +1,203 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int MaxDState = MAX_DSTATE; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 
4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + 2 * sizeof(typename BlockLoadWeightT::TempStorage), + 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks * params.dstate; + + float D_val = 0; // attention! 
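+    // D (the skip connection) is optional; with the zero default above, the D_val * u_val
+    // term added to out_vals below becomes a no-op when D_ptr is null.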
+ if (params.D_ptr != nullptr) { + D_val = reinterpret_cast(params.D_ptr)[dim_id]; + } + float delta_bias = 0; + if (params.delta_bias_ptr != nullptr) { + delta_bias = reinterpret_cast(params.delta_bias_ptr)[dim_id]; + } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNItems], delta_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[i]); + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if (params.delta_softplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + delta_u_vals[i] = delta_vals[i] * u_val; + out_vals[i] = D_val * u_val; + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + constexpr float kLog2e = M_LOG2E; + weight_t A_val = A[state_idx * params.A_dstate_stride]; + A_val *= kLog2e; + weight_t B_vals[kNItems], C_vals[kNItems]; + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize)); + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight1, (params.seqlen - chunk * kChunkSize)); + __syncthreads(); + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. 
+ if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[chunk * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + out_vals[i] += thread_data[i].y * C_vals[i]; + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(out, out_vals, smem_store, params.seqlen - chunk * kChunkSize); + Bvar += kChunkSize; + Cvar += kChunkSize; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } +} diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_bwd_kernel_ndstate.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_bwd_kernel_ndstate.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6041a26da357493f2d30aca9e13595a33625fc15 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_bwd_kernel_ndstate.cuh @@ -0,0 +1,302 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan_ndstate.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 
4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && 3; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + 2 * sizeof(typename BlockLoadWeightT::TempStorage), + 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
+ extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + 1); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * params.dout_d_stride; + weight_t A_val = reinterpret_cast(params.A_ptr)[dim_id]; + constexpr float kLog2e = M_LOG2E; + weight_t A_scaled = A_val * kLog2e; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; + float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id]; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; + float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks); + float dD_val = 0; + float ddelta_bias_val = 0; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize; + Cvar += (params.n_chunks - 1) * kChunkSize; + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; + input_t dout_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + u -= kChunkSize; + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + dout -= kChunkSize; + + float dout_vals[kNItems], delta_vals[kNItems]; + float du_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[i] = float(dout_vals_load[i]); + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if constexpr (kDeltaSoftplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } + + float ddelta_vals[kNItems] = {0}; + __syncthreads(); + { + weight_t B_vals[kNItems], C_vals[kNItems]; + load_weight(Bvar, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize)); + auto &smem_load_weight_C = smem_load_weight1; + load_weight(Cvar, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize)); + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + thread_data[i] = make_float2(delta_a_exp, delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? (chunk % 2): threadIdx.x + 2] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[i] * C_vals[i]; + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[(chunk + 1) % 2]) + : smem_delta_a[threadIdx.x + 1 + 2]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[chunk - 1] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[0] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[0] = postfix_op.running_prefix; } + weight_t dA_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = dx * B_vals[i]; + du_vals[i] += ddelta_u * delta_vals[i]; + const float a = thread_data[i].y - (delta_vals[i] * float(u_vals[i]) * B_vals[i]); + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; + dA_val += dx * delta_vals[i] * a; + dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); + dC_vals[i] = dout_vals[i] * thread_data[i].y; + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + auto &smem_exchange_C = smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + if (threadIdx.x == 0) { + smem_da[0] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[0]; + } + } + + if constexpr (kDeltaSoftplus) { + input_t delta_vals_load[kNItems]; + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + delta -= kChunkSize; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[i]) + delta_bias; + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[i] = delta_val <= 20.f + ? 
ddelta_vals[i] / (1.f + delta_val_neg_exp) + : ddelta_vals[i]; + } + } + + __syncthreads(); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * params.ddelta_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); + Bvar -= kChunkSize; + Cvar -= kChunkSize; + } + + if (params.dD_ptr != nullptr) { + __syncthreads(); + dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); + if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); + if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } + } + __syncthreads(); + if (threadIdx.x == 0) { gpuAtomicAdd(dA, smem_da[0]); } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + sizeof(typename Ktraits::scan_t) + (kNThreads + 4) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..62f2039e5832ec1b06681a985a61a80d6d7d31ea --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ +#include "selective_scan_bwd_kernel_ndstate.cuh" + +template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..4dc81499467fe8efa03d460de1e499397577c7a2 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#include "selective_scan_fwd_kernel_ndstate.cuh" + +template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_fwd_kernel_ndstate.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_fwd_kernel_ndstate.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5d51e6ea13c4ba0d9776970aadf82c7c96873a4e --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_fwd_kernel_ndstate.cuh @@ -0,0 +1,200 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include <c10/util/BFloat16.h> +#include <c10/util/Half.h> +#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include <cub/block/block_load.cuh> +#include <cub/block/block_store.cuh> +#include <cub/block/block_scan.cuh> + +#include "selective_scan_ndstate.h" +#include "selective_scan_common.h" +#include "static_switch.h" + +template<int kNThreads_, int kNItems_, bool kIsEvenLen_, typename input_t_, typename weight_t_> +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ?
4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + 2 * sizeof(typename BlockLoadWeightT::TempStorage), + 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + constexpr float kLog2e = M_LOG2E; + weight_t A_val = reinterpret_cast(params.A_ptr)[dim_id] * kLog2e; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks; + + float D_val = 0; // attention! 
+ if (params.D_ptr != nullptr) { + D_val = reinterpret_cast(params.D_ptr)[dim_id]; + } + float delta_bias = 0; + if (params.delta_bias_ptr != nullptr) { + delta_bias = reinterpret_cast(params.delta_bias_ptr)[dim_id]; + } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNItems], delta_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[i]); + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if (params.delta_softplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + delta_u_vals[i] = delta_vals[i] * u_val; + out_vals[i] = D_val * u_val; + } + + __syncthreads(); + { + weight_t B_vals[kNItems], C_vals[kNItems]; + load_weight(Bvar, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize)); + auto &smem_load_weight_C = smem_load_weight1; + load_weight(Cvar, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize)); + __syncthreads(); + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[0] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. 
+ if (threadIdx.x == 0) { + smem_running_prefix[0] = prefix_op.running_prefix; + x[chunk] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + out_vals[i] += thread_data[i].y * C_vals[i]; + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(out, out_vals, smem_store, params.seqlen - chunk * kChunkSize); + Bvar += kChunkSize; + Cvar += kChunkSize; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } +} diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ddc2740ef2c63c3a9cbbbe0fac205144334dd44 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp @@ -0,0 +1,341 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include "selective_scan_ndstate.h" +#define MAX_DSTATE 256 + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +using weight_t = float; + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t n_groups, + const size_t n_chunks, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool delta_softplus) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +void set_ssm_params_bwd(SSMParamsBwd ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t n_groups, + const size_t n_chunks, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + const at::Tensor dout, + const at::Tensor du, + const at::Tensor ddelta, + const at::Tensor dA, + const at::Tensor dB, + const at::Tensor dC, + void* dD_ptr, + void* ddelta_bias_ptr, + bool delta_softplus) { + // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z + set_ssm_params_fwd(params, batch, dim, seqlen, n_groups, n_chunks, + u, delta, A, B, C, dout, + D_ptr, delta_bias_ptr, x_ptr, delta_softplus); + + // Set the pointers and strides. + params.dout_ptr = dout.data_ptr(); + params.du_ptr = du.data_ptr(); + params.dA_ptr = dA.data_ptr(); + params.dB_ptr = dB.data_ptr(); + params.dC_ptr = dC.data_ptr(); + params.dD_ptr = dD_ptr; + params.ddelta_ptr = ddelta.data_ptr(); + params.ddelta_bias_ptr = ddelta_bias_ptr; + // All stride are in elements, not bytes. 
+ params.dout_batch_stride = dout.stride(0); + params.dout_d_stride = dout.stride(1); + params.dA_d_stride = dA.stride(0); + params.dB_batch_stride = dB.stride(0); + params.dB_group_stride = dB.stride(1); + params.dC_batch_stride = dC.stride(0); + params.dC_group_stride = dC.stride(1); + params.du_batch_stride = du.stride(0); + params.du_d_stride = du.stride(1); + params.ddelta_batch_stride = ddelta.stride(0); + params.ddelta_d_stride = ddelta.stride(1); + +} + +std::vector +selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &delta_bias_, + bool delta_softplus, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == input_type); + TORCH_CHECK(C.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int n_groups = B.size(1); + + TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim); + CHECK_SHAPE(B, batch_size, n_groups, seqlen); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + CHECK_SHAPE(C, batch_size, n_groups, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel + at::Tensor out = torch::empty_like(delta); + at::Tensor x; + x = torch::empty({batch_size, dim, n_chunks, 1 * 2}, u.options().dtype(weight_type)); + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, n_groups, n_chunks, + u, delta, A, B, C, out, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? 
delta_bias_.value().data_ptr() : nullptr, + x.data_ptr(), + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda<1, input_t, weight_t>(params, stream); + }); + std::vector result = {out, x}; + return result; +} + +std::vector +selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &delta_bias_, + const at::Tensor &dout, + const c10::optional &x_, + bool delta_softplus, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == input_type); + TORCH_CHECK(C.scalar_type() == input_type); + TORCH_CHECK(dout.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + TORCH_CHECK(dout.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int n_groups = B.size(1); + + TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim); + CHECK_SHAPE(B, batch_size, n_groups, seqlen); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + CHECK_SHAPE(C, batch_size, n_groups, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + CHECK_SHAPE(dout, batch_size, dim, seqlen); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor out; + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } + if (x_.has_value()) { + auto x = x_.value(); + TORCH_CHECK(x.scalar_type() == weight_type); + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.is_contiguous()); + CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * 1); + } + + at::Tensor du = torch::empty_like(u); + at::Tensor ddelta = torch::empty_like(delta); + at::Tensor dA = torch::zeros_like(A); + at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32)); + at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32)); + at::Tensor dD; + if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } + at::Tensor ddelta_bias; + if (delta_bias_.has_value()) { 
ddelta_bias = torch::zeros_like(delta_bias_.value()); } + + SSMParamsBwd params; + set_ssm_params_bwd(params, batch_size, dim, seqlen, n_groups, n_chunks, + u, delta, A, B, C, out, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x_.has_value() ? x_.value().data_ptr() : nullptr, + dout, du, ddelta, dA, dB, dC, + D_.has_value() ? dD.data_ptr() : nullptr, + delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { + selective_scan_bwd_cuda<1, input_t, weight_t>(params, stream); + }); + std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; + return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &selective_scan_fwd, "Selective scan forward"); + m.def("bwd", &selective_scan_bwd, "Selective scan backward"); +} diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.h b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.h new file mode 100644 index 0000000000000000000000000000000000000000..070642b2b68a45ffcaaf71a878a68c7b1752cb15 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.h @@ -0,0 +1,84 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMScanParamsBase { + using index_t = uint32_t; + + int batch, seqlen, n_chunks; + index_t a_batch_stride; + index_t b_batch_stride; + index_t out_batch_stride; + + // Common data pointers. + void *__restrict__ a_ptr; + void *__restrict__ b_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, n_groups, n_chunks; + int dim_ngroups_ratio; + + bool delta_softplus; + + index_t A_d_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + + // Common data pointers. 
+ void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; +}; + +struct SSMParamsBwd: public SSMParamsBase { + index_t dout_batch_stride; + index_t dout_d_stride; + index_t dA_d_stride; + index_t dB_batch_stride; + index_t dB_group_stride; + index_t dB_d_stride; + index_t dC_batch_stride; + index_t dC_group_stride; + index_t dC_d_stride; + index_t du_batch_stride; + index_t du_d_stride; + index_t ddelta_batch_stride; + index_t ddelta_d_stride; + + // Common data pointers. + void *__restrict__ dout_ptr; + void *__restrict__ dA_ptr; + void *__restrict__ dB_ptr; + void *__restrict__ dC_ptr; + void *__restrict__ dD_ptr; + void *__restrict__ du_ptr; + void *__restrict__ ddelta_ptr; + void *__restrict__ ddelta_bias_ptr; +}; diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_bwd_kernel_nrow.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_bwd_kernel_nrow.cuh new file mode 100644 index 0000000000000000000000000000000000000000..659e6c77136ce9e5ffe8c3de35ab881aeb2e182c --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_bwd_kernel_nrow.cuh @@ -0,0 +1,344 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. 
+ static constexpr int kMinBlocks = kNThreads == 128 && 3; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + 2 * sizeof(typename BlockLoadWeightT::TempStorage), + 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
+ extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + // scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + kNRows * (2 * Ktraits::MaxDState + kNThreads)); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + kNRows * 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + kNRows * Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int dim_id_nrow = dim_id * kNRows; + const int group_id = dim_id_nrow / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id_nrow * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id_nrow * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id_nrow * params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id_nrow * params.A_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id_nrow * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id_nrow; + float *D_val = params.D_ptr == nullptr ? nullptr : reinterpret_cast(params.D_ptr) + dim_id_nrow; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id_nrow; + float *delta_bias = params.delta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.delta_bias_ptr) + dim_id_nrow; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id_nrow) * (params.n_chunks) * params.dstate; + float dD_val[kNRows] = {0}; + float ddelta_bias_val[kNRows] = {0}; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize; + Cvar += (params.n_chunks - 1) * kChunkSize; + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNRows][kNItems]; + input_t delta_vals_load[kNRows][kNItems]; + input_t dout_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(dout + r * params.dout_d_stride, dout_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u -= kChunkSize; + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + dout -= kChunkSize; + + float dout_vals[kNRows][kNItems], delta_vals[kNRows][kNItems]; + float du_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[r][i] = float(dout_vals_load[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + (delta_bias == nullptr? 0: delta_bias[r]); + if constexpr (kDeltaSoftplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[r][i] = (D_val == nullptr? 0: D_val[r]) * dout_vals[r][i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val[r] += dout_vals[r][i] * float(u_vals[r][i]); } + } + + float ddelta_vals[kNRows][kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + weight_t A_scaled[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + constexpr float kLog2e = M_LOG2E; + A_scaled[r] = A_val[r] * kLog2e; + } + weight_t B_vals[kNItems], C_vals[kNItems]; + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize)); + auto &smem_load_weight_C = smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize)); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[r][i] * A_scaled[r]); + thread_data[i] = make_float2(delta_a_exp, delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + if (i == 0) { + // smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads) : threadIdx.x + 2 * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads)] = delta_a_exp; + smem_delta_a[threadIdx.x == 0 ? 
state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState : threadIdx.x + kNRows * 2 * Ktraits::MaxDState] = delta_a_exp; + + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[r][i] * C_vals[i]; + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + // ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads)]) + // : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState + r * (2 * Ktraits::MaxDState + kNThreads)]; + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + kNRows * 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(r * params.n_chunks + chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; } + weight_t dA_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = dx * B_vals[i]; + du_vals[r][i] += ddelta_u * delta_vals[r][i]; + const float a = thread_data[i].y - (delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + dx * A_val[r] * a; + dA_val += dx * delta_vals[r][i] * a; + dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]); + dC_vals[i] = dout_vals[r][i] * thread_data[i].y; + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + auto &smem_exchange_C = smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + if (threadIdx.x == 0) { + smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? 
dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState]; + } + } + } + + if constexpr (kDeltaSoftplus) { + input_t delta_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + delta -= kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[r][i]) + (delta_bias == nullptr? 0: delta_bias[r]); + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[r][i] = delta_val <= 20.f + ? ddelta_vals[r][i] / (1.f + delta_val_neg_exp) + : ddelta_vals[r][i]; + } + } + } + + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val[r] += ddelta_vals[r][i]; } + } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id_nrow * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id_nrow * params.ddelta_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(du + r * params.du_d_stride, du_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta + r * params.ddelta_d_stride, ddelta_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + Bvar -= kChunkSize; + Cvar -= kChunkSize; + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (params.dD_ptr != nullptr) { + __syncthreads(); + dD_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(dD[r]), dD_val[r]); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(ddelta_bias[r]), ddelta_bias_val[r]); } + } + __syncthreads(); + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride + r * params.dA_d_stride]), smem_da[state_idx + r * Ktraits::MaxDState]); + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * kNRows * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, knrows, input_t, 
weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..302adc5fda0307973bbef8cca29f5e7eb725e44d --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#include "selective_scan_bwd_kernel_nrow.cuh" + +template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu new file mode 100644 index 0000000000000000000000000000000000000000..2d227de551da2d1ffcc5501963b1f8f7a6e4a128 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#include "selective_scan_bwd_kernel_nrow.cuh" + +template void selective_scan_bwd_cuda<2, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu new file mode 100644 index 0000000000000000000000000000000000000000..9c204ed2f8aea4a39a3117049a88db4c03c51281 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu @@ -0,0 +1,8 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ +#include "selective_scan_bwd_kernel_nrow.cuh" + +template void selective_scan_bwd_cuda<3, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu new file mode 100644 index 0000000000000000000000000000000000000000..c40f3a4e978489e7b29018fb24fbe3d7a640acdd --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu @@ -0,0 +1,8 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#include "selective_scan_bwd_kernel_nrow.cuh" + +template void selective_scan_bwd_cuda<4, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..8d0ce0d8e980fb90372e6957e11fb54b5078b7c8 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#include "selective_scan_fwd_kernel_nrow.cuh" + +template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu new file mode 100644 index 0000000000000000000000000000000000000000..955e70201657910194c267046c93823d458a053e --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ +#include "selective_scan_fwd_kernel_nrow.cuh" + +template void selective_scan_fwd_cuda<2, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu new file mode 100644 index 0000000000000000000000000000000000000000..fc5c05ee369cb84caac5dda55ab477fad50020ca --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#include "selective_scan_fwd_kernel_nrow.cuh" + +template void selective_scan_fwd_cuda<3, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu new file mode 100644 index 0000000000000000000000000000000000000000..e078c2c0e7404551e879ce52347419eb1dd60c45 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#include "selective_scan_fwd_kernel_nrow.cuh" + +template void selective_scan_fwd_cuda<4, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_fwd_kernel_nrow.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_fwd_kernel_nrow.cuh new file mode 100644 index 0000000000000000000000000000000000000000..73741194eaf232933f8ac2f54f6944b2813ab7bd --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_fwd_kernel_nrow.cuh @@ -0,0 +1,238 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + 2 * sizeof(typename BlockLoadWeightT::TempStorage), + 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
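+ // Sketch of the hand-partitioned dynamic shared memory that follows (derived from the
+ // offsets used below and in selective_scan_fwd_launch): the BlockLoad / BlockLoadWeight /
+ // BlockStore temp storages alias the first kSmemIOSize bytes (their uses are separated by
+ // __syncthreads, and two BlockLoadWeight storages let B and C load without an extra sync),
+ // the BlockScan temp storage starts at kSmemIOSize, and the per-(row, state) running
+ // prefixes live past kSmemSize, which is why the launcher requests
+ // kSmemSize + kNRows * MaxDState * sizeof(scan_t) bytes of dynamic shared memory.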
+ extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int dim_id_nrow = dim_id * kNRows; + const int group_id = dim_id_nrow / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id_nrow * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id_nrow * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id_nrow * params.A_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id_nrow) * params.n_chunks * params.dstate; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id_nrow + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id_nrow + r]; + } + } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. 
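+ // exp(delta * A) is evaluated as exp2f(delta * A * log2(e)), using exp(x) = 2^(x * log2(e));
+ // exp2f maps to a single hardware instruction, so A is pre-scaled once per state here.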
+ constexpr float kLog2e = M_LOG2E; + A_val[r] *= kLog2e; + } + weight_t B_vals[kNItems], C_vals[kNItems]; + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize)); + auto &smem_load_weight_C = smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize)); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + B_vals[i] * delta_u_vals[r][i]); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx + r * Ktraits::MaxDState] = prefix_op.running_prefix; + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + out_vals[r][i] += thread_data[i].y * C_vals[i]; + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id_nrow * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + Bvar += kChunkSize; + Cvar += kChunkSize; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + 
selective_scan_fwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); + } +} diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_nrow.cpp b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_nrow.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6752826ed7ac297b91a3716700ebf19d5dfb97d6 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_nrow.cpp @@ -0,0 +1,367 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include "selective_scan.h" +#define MAX_DSTATE 256 + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +using weight_t = float; + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define INT_SWITCH(INT, NAME, ...) [&] { \ + if (INT == 2) {constexpr int NAME = 2; __VA_ARGS__(); } \ + else if (INT == 3) {constexpr int NAME = 3; __VA_ARGS__(); } \ + else if (INT == 4) {constexpr int NAME = 4; __VA_ARGS__(); } \ + else {constexpr int NAME = 1; __VA_ARGS__(); } \ +}() \ + + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool delta_softplus) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + + // All stride are in elements, not bytes. 
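+ // at::Tensor::stride() already reports strides in elements, and the kernels index typed
+ // pointers (input_t* / weight_t*) directly, so no byte conversion is needed below.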
+ params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + params.B_dstate_stride = B.stride(2); + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + params.C_dstate_stride = C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +void set_ssm_params_bwd(SSMParamsBwd ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + const at::Tensor dout, + const at::Tensor du, + const at::Tensor ddelta, + const at::Tensor dA, + const at::Tensor dB, + const at::Tensor dC, + void* dD_ptr, + void* ddelta_bias_ptr, + bool delta_softplus) { + // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z + set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, + u, delta, A, B, C, dout, + D_ptr, delta_bias_ptr, x_ptr, delta_softplus); + + // Set the pointers and strides. + params.dout_ptr = dout.data_ptr(); + params.du_ptr = du.data_ptr(); + params.dA_ptr = dA.data_ptr(); + params.dB_ptr = dB.data_ptr(); + params.dC_ptr = dC.data_ptr(); + params.dD_ptr = dD_ptr; + params.ddelta_ptr = ddelta.data_ptr(); + params.ddelta_bias_ptr = ddelta_bias_ptr; + // All stride are in elements, not bytes. + params.dout_batch_stride = dout.stride(0); + params.dout_d_stride = dout.stride(1); + params.dA_d_stride = dA.stride(0); + params.dA_dstate_stride = dA.stride(1); + params.dB_batch_stride = dB.stride(0); + params.dB_group_stride = dB.stride(1); + params.dB_dstate_stride = dB.stride(2); + params.dC_batch_stride = dC.stride(0); + params.dC_group_stride = dC.stride(1); + params.dC_dstate_stride = dC.stride(2); + params.du_batch_stride = du.stride(0); + params.du_d_stride = du.stride(1); + params.ddelta_batch_stride = ddelta.stride(0); + params.ddelta_d_stride = ddelta.stride(1); + +} + +std::vector +selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &delta_bias_, + bool delta_softplus, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == input_type); + TORCH_CHECK(C.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = B.size(1); + + TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by 
n_groups * nrows"); + TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel + at::Tensor out = torch::empty_like(delta); + at::Tensor x; + x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, + u, delta, A, B, C, out, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.data_ptr(), + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + INT_SWITCH(nrows, kNRows, [&] { + selective_scan_fwd_cuda(params, stream); + }); + }); + std::vector result = {out, x}; + return result; +} + +std::vector +selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &delta_bias_, + const at::Tensor &dout, + const c10::optional &x_, + bool delta_softplus, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == input_type); + TORCH_CHECK(C.scalar_type() == input_type); + TORCH_CHECK(dout.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + TORCH_CHECK(dout.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = B.size(1); + + TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows"); + TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, 
batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + CHECK_SHAPE(dout, batch_size, dim, seqlen); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor out; + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } + if (x_.has_value()) { + auto x = x_.value(); + TORCH_CHECK(x.scalar_type() == weight_type); + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.is_contiguous()); + CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); + } + + at::Tensor du = torch::empty_like(u); + at::Tensor ddelta = torch::empty_like(delta); + at::Tensor dA = torch::zeros_like(A); + at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32)); + at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32)); + at::Tensor dD; + if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } + at::Tensor ddelta_bias; + if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } + + SSMParamsBwd params; + set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, + u, delta, A, B, C, out, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x_.has_value() ? x_.value().data_ptr() : nullptr, + dout, du, ddelta, dA, dB, dC, + D_.has_value() ? dD.data_ptr() : nullptr, + delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { + // constexpr int kNRows = 1; + INT_SWITCH(nrows, kNRows, [&] { + selective_scan_bwd_cuda(params, stream); + }); + }); + std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; + return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &selective_scan_fwd, "Selective scan forward"); + m.def("bwd", &selective_scan_bwd, "Selective scan backward"); +} diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_bwd_kernel_oflex.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_bwd_kernel_oflex.cuh new file mode 100644 index 0000000000000000000000000000000000000000..352b3601fbaeff679c46eaad8ae8c7737fe400d8 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_bwd_kernel_oflex.cuh @@ -0,0 +1,323 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + using output_t = output_t_; + + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int MaxDState = MAX_DSTATE; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && 3; + static constexpr int kNLoadsOutput = sizeof(output_t) * kNLoads / kNBytes; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockLoadOutputT = cub::BlockLoad; + using BlockLoadOutputVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + 2 * sizeof(typename BlockLoadWeightT::TempStorage), + 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockLoadOutputT::TempStorage), + sizeof(typename BlockLoadOutputVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = 2 * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using output_t = typename Ktraits::output_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
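+ // Same hand-partitioned shared-memory scheme as the forward kernel, with extra regions
+ // appended past kSmemSize for the backward pass (matching the bytes added in
+ // selective_scan_bwd_launch): a delta_a exchange buffer of 2 * MaxDState + kNThreads floats
+ // that hands each thread's first exp(delta*A) to its left neighbour across the reverse scan,
+ // the per-state running postfix of the reverse scan, and the per-state dA accumulators.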
+ extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load1 = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (batch_id * params.dB_batch_stride + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (batch_id * params.dC_batch_stride + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; + float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id]; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; + float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; + float dD_val = 0; + float ddelta_bias_val = 0; + + output_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + dim_id * params.dout_d_stride; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize; + Cvar += (params.n_chunks - 1) * kChunkSize; + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; + float dout_vals[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + if constexpr (std::is_same_v) { + input_t dout_vals_load[kNItems]; + load_input(reinterpret_cast(dout), dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + Converter::to_float(dout_vals_load, dout_vals); + } else { + static_assert(std::is_same_v); + load_output(dout, dout_vals, smem_load1, params.seqlen - chunk * kChunkSize); + } + u -= kChunkSize; + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + dout -= kChunkSize; + + float delta_vals[kNItems]; + float du_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if constexpr (kDeltaSoftplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } + + float ddelta_vals[kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + constexpr float kLog2e = M_LOG2E; + weight_t A_val = A[state_idx * params.A_dstate_stride]; + weight_t A_scaled = A_val * kLog2e; + weight_t B_vals[kNItems], C_vals[kNItems]; + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize)); + auto &smem_load_weight_C = smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize)); + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + thread_data[i] = make_float2(delta_a_exp, delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState: threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[i] * C_vals[i]; + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? 
x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = dx * B_vals[i]; + du_vals[i] += ddelta_u * delta_vals[i]; + const float a = thread_data[i].y - (delta_vals[i] * float(u_vals[i]) * B_vals[i]); + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; + dA_val += dx * delta_vals[i] * a; + dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); + dC_vals[i] = dout_vals[i] * thread_data[i].y; + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + auto &smem_exchange_C = smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } + + if constexpr (kDeltaSoftplus) { + input_t delta_vals_load[kNItems]; + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + delta -= kChunkSize; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[i]) + delta_bias; + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[i] = delta_val <= 20.f + ? 
ddelta_vals[i] / (1.f + delta_val_neg_exp) + : ddelta_vals[i]; + } + } + + __syncthreads(); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * params.ddelta_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); + Bvar -= kChunkSize; + Cvar -= kChunkSize; + } + + if (params.dD_ptr != nullptr) { + __syncthreads(); + dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); + if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); + if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } + } + __syncthreads(); + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, input_t, weight_t, output_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, input_t, weight_t, output_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, input_t, weight_t, output_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, input_t, weight_t, output_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, input_t, weight_t, output_t>(params, stream); + } +} + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..fbb117b96ddfa2eb0fef87e649d222c1a994bed1 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ +#include "selective_scan_bwd_kernel_oflex.cuh" + +template void selective_scan_bwd_cuda<1, float, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::Half, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::BFloat16, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::Half, float, at::Half>(SSMParamsBwd ¶ms, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBwd ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..a781bf739a441d71dd3c80ae83e44a1cb6361224 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#include "selective_scan_fwd_kernel_oflex.cuh" + +template void selective_scan_fwd_cuda<1, float, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::BFloat16, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, float, at::Half>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBase ¶ms, cudaStream_t stream); + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_fwd_kernel_oflex.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_fwd_kernel_oflex.cuh new file mode 100644 index 0000000000000000000000000000000000000000..84b87f32d4f6c838a0b963254615de5ba5f7d5da --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_fwd_kernel_oflex.cuh @@ -0,0 +1,211 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + using output_t = output_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int MaxDState = MAX_DSTATE; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 
4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + static constexpr int kNLoadsOutput = sizeof(output_t) * kNLoads / kNBytes; + static constexpr bool kDirectIOOutput = kDirectIO && (kNLoadsOutput == 1); + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + using BlockStoreOutputT = cub::BlockStore; + using BlockStoreOutputVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + 2 * sizeof(typename BlockLoadWeightT::TempStorage), + 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage), + sizeof(typename BlockStoreOutputT::TempStorage), + sizeof(typename BlockStoreOutputVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using output_t = typename Ktraits::output_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store1 = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks * params.dstate; + + float D_val = 0; // attention! 
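+ // D is the optional skip (feed-through) term of the SSM: when present, each output element
+ // accumulates D * u on top of the scanned C * h contribution (see out_vals below).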
+ if (params.D_ptr != nullptr) { + D_val = reinterpret_cast(params.D_ptr)[dim_id]; + } + float delta_bias = 0; + if (params.delta_bias_ptr != nullptr) { + delta_bias = reinterpret_cast(params.delta_bias_ptr)[dim_id]; + } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNItems], delta_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[i]); + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if (params.delta_softplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + delta_u_vals[i] = delta_vals[i] * u_val; + out_vals[i] = D_val * u_val; + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + constexpr float kLog2e = M_LOG2E; + weight_t A_val = A[state_idx * params.A_dstate_stride]; + A_val *= kLog2e; + weight_t B_vals[kNItems], C_vals[kNItems]; + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize)); + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight1, (params.seqlen - chunk * kChunkSize)); + __syncthreads(); + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. 
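+            // Thread 0 owns the final inclusive prefix of this chunk: it is kept in shared memory so
+            // the next chunk's scan continues from it, and checkpointed into x so the backward pass
+            // can restart each chunk from its boundary state instead of rescanning from the beginning.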
+ if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[chunk * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + out_vals[i] += thread_data[i].y * C_vals[i]; + } + } + + output_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output1(out, out_vals, smem_store1, params.seqlen - chunk * kChunkSize); + Bvar += kChunkSize; + Cvar += kChunkSize; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t, output_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t, output_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t, output_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t, output_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t, output_t>(params, stream); + } +} diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..671bab67967a55b6cd2aac984bc53a7a98876b48 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_oflex.cpp @@ -0,0 +1,363 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include "selective_scan.h" +#define MAX_DSTATE 256 + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +using weight_t = float; + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool delta_softplus) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + params.B_dstate_stride = B.stride(2); + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + params.C_dstate_stride = C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +void set_ssm_params_bwd(SSMParamsBwd ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + const at::Tensor dout, + const at::Tensor du, + const at::Tensor ddelta, + const at::Tensor dA, + const at::Tensor dB, + const at::Tensor dC, + void* dD_ptr, + void* ddelta_bias_ptr, + bool delta_softplus) { + // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z + set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, + u, delta, A, B, C, dout, + D_ptr, delta_bias_ptr, x_ptr, delta_softplus); + + // Set the pointers and strides. + params.dout_ptr = dout.data_ptr(); + params.du_ptr = du.data_ptr(); + params.dA_ptr = dA.data_ptr(); + params.dB_ptr = dB.data_ptr(); + params.dC_ptr = dC.data_ptr(); + params.dD_ptr = dD_ptr; + params.ddelta_ptr = ddelta.data_ptr(); + params.ddelta_bias_ptr = ddelta_bias_ptr; + // All stride are in elements, not bytes. 
+ params.dout_batch_stride = dout.stride(0); + params.dout_d_stride = dout.stride(1); + params.dA_d_stride = dA.stride(0); + params.dA_dstate_stride = dA.stride(1); + params.dB_batch_stride = dB.stride(0); + params.dB_group_stride = dB.stride(1); + params.dB_dstate_stride = dB.stride(2); + params.dC_batch_stride = dC.stride(0); + params.dC_group_stride = dC.stride(1); + params.dC_dstate_stride = dC.stride(2); + params.du_batch_stride = du.stride(0); + params.du_d_stride = du.stride(1); + params.ddelta_batch_stride = ddelta.stride(0); + params.ddelta_d_stride = ddelta.stride(1); + +} + +std::vector +selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &delta_bias_, + bool delta_softplus, + int nrows, + bool out_float + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == input_type); + TORCH_CHECK(C.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = B.size(1); + + TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups"); + TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel + at::Tensor out = torch::empty({batch_size, dim, seqlen}, u.options().dtype(out_float? (at::ScalarType::Float): input_type)); + at::Tensor x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, + u, delta, A, B, C, out, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? 
delta_bias_.value().data_ptr() : nullptr, + x.data_ptr(), + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + if (!out_float) { + selective_scan_fwd_cuda<1, input_t, weight_t, input_t>(params, stream); + } else { + selective_scan_fwd_cuda<1, input_t, weight_t, float>(params, stream); + } + }); + std::vector result = {out, x}; + return result; +} + +std::vector +selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &delta_bias_, + const at::Tensor &dout, + const c10::optional &x_, + bool delta_softplus, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + auto output_type = dout.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + TORCH_CHECK(output_type == input_type || output_type == at::ScalarType::Float); + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == input_type); + TORCH_CHECK(C.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + TORCH_CHECK(dout.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = B.size(1); + + TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups"); + TORCH_CHECK(dstate <= MAX_DSTATE, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + CHECK_SHAPE(dout, batch_size, dim, seqlen); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor out; + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } + if (x_.has_value()) { + auto x = x_.value(); + TORCH_CHECK(x.scalar_type() == weight_type); + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.is_contiguous()); + CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); + } + + at::Tensor du = torch::empty_like(u); + at::Tensor ddelta = 
torch::empty_like(delta); + at::Tensor dA = torch::zeros_like(A); + at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32)); + at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32)); + at::Tensor dD; + if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } + at::Tensor ddelta_bias; + if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } + + SSMParamsBwd params; + set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, + u, delta, A, B, C, out, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x_.has_value() ? x_.value().data_ptr() : nullptr, + dout, du, ddelta, dA, dB, dC, + D_.has_value() ? dD.data_ptr() : nullptr, + delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { + if (output_type == input_type) { + selective_scan_bwd_cuda<1, input_t, weight_t, input_t>(params, stream); + } else { + selective_scan_bwd_cuda<1, input_t, weight_t, float>(params, stream); + } + }); + std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; + return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &selective_scan_fwd, "Selective scan forward"); + m.def("bwd", &selective_scan_bwd, "Selective scan backward"); +} diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/reverse_scan.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/reverse_scan.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b312df182d90b6d61035223b075691683cfaf9a9 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/reverse_scan.cuh @@ -0,0 +1,403 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +// #include +#include "uninitialized_copy.cuh" +#include "cub_extra.cuh" + +/** + * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) { + static_assert(LENGTH > 0); + T retval = input[LENGTH - 1]; + #pragma unroll + for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); } + return retval; +} + +/** + * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. 
+ */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadReverseScanInclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + T inclusive = postfix; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(inclusive, input[i]); + output[i] = inclusive; + } +} + +/** + * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadReverseScanExclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + // Careful, output maybe be aliased to input + T exclusive = postfix; + T inclusive; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(exclusive, input[i]); + output[i] = exclusive; + exclusive = inclusive; + } + return inclusive; +} + + +/** + * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp. + * + * LOGICAL_WARP_THREADS must be a power-of-two + */ +template < + typename T, ///< Data type being scanned + int LOGICAL_WARP_THREADS ///< Number of threads per logical warp + > +struct WarpReverseScan { + //--------------------------------------------------------------------- + // Constants and type definitions + //--------------------------------------------------------------------- + + /// Whether the logical warp size and the PTX warp size coincide + static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)); + /// The number of warp scan steps + static constexpr int STEPS = cub::Log2::VALUE; + static_assert(LOGICAL_WARP_THREADS == 1 << STEPS); + + + //--------------------------------------------------------------------- + // Thread fields + //--------------------------------------------------------------------- + + /// Lane index in logical warp + unsigned int lane_id; + + /// Logical warp index in 32-thread physical warp + unsigned int warp_id; + + /// 32-thread physical warp member mask of logical warp + unsigned int member_mask; + + //--------------------------------------------------------------------- + // Construction + //--------------------------------------------------------------------- + + /// Constructor + explicit __device__ __forceinline__ + WarpReverseScan() + : lane_id(cub::LaneId()) + , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS)) + // , member_mask(cub::WarpMask(warp_id)) + , member_mask(WarpMask(warp_id)) + { + if (!IS_ARCH_WARP) { + lane_id = lane_id % LOGICAL_WARP_THREADS; + } + } + + + /// Broadcast + __device__ __forceinline__ T Broadcast( + T input, ///< [in] The value to broadcast + int src_lane) ///< [in] Which warp lane is to do the broadcasting + { + return cub::ShuffleIndex(input, src_lane, member_mask); + } + + + /// Inclusive scan + template + __device__ __forceinline__ void InclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. 
+ ScanOpT scan_op) ///< [in] Binary scan operator + { + inclusive_output = input; + #pragma unroll + for (int STEP = 0; STEP < STEPS; STEP++) { + int offset = 1 << STEP; + T temp = cub::ShuffleDown( + inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask + ); + // Perform scan op if from a valid peer + inclusive_output = static_cast(lane_id) >= LOGICAL_WARP_THREADS - offset + ? inclusive_output : scan_op(temp, inclusive_output); + } + } + + /// Exclusive scan + // Get exclusive from inclusive + template + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOpT scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + T inclusive_output; + InclusiveReverseScan(input, inclusive_output, scan_op); + warp_aggregate = cub::ShuffleIndex(inclusive_output, 0, member_mask); + // initial value unknown + exclusive_output = cub::ShuffleDown( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + + /** + * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined. + */ + template + __device__ __forceinline__ void ReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. + T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. + ScanOpT scan_op) ///< [in] Binary scan operator + { + InclusiveReverseScan(input, inclusive_output, scan_op); + // initial value unknown + exclusive_output = cub::ShuffleDown( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + +}; + +/** + * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. 
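+ * A postfix (reverse) scan runs from the last element toward the first; the selective-scan backward
+ * kernel uses it to accumulate state gradients in reverse time order.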
+ */ +template < + typename T, ///< Data type being scanned + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure + > +struct BlockReverseScan { + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + /// The thread block size in threads + static constexpr int BLOCK_THREADS = BLOCK_DIM_X; + + /// Layout type for padded thread block raking grid + using BlockRakingLayout = cub::BlockRakingLayout; + // The number of reduction elements is not a multiple of the number of raking threads for now + static_assert(BlockRakingLayout::UNGUARDED); + + /// Number of raking threads + static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; + /// Number of raking elements per warp synchronous raking thread + static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; + /// Cooperative work can be entirely warp synchronous + static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); + + /// WarpReverseScan utility type + using WarpReverseScan = WarpReverseScan; + + /// Shared memory storage layout type + struct _TempStorage { + typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : cub::Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + T cached_segment[SEGMENT_LENGTH]; + + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + /// Performs upsweep raking reduction, returning the aggregate + template + __device__ __forceinline__ T Upsweep(ScanOp scan_op) { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data into registers + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + T raking_partial = cached_segment[SEGMENT_LENGTH - 1]; + #pragma unroll + for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) { + raking_partial = scan_op(raking_partial, cached_segment[i]); + } + return raking_partial; + } + + + /// Performs exclusive downsweep raking scan + template + __device__ __forceinline__ void ExclusiveDownsweep( + ScanOp scan_op, + T raking_partial) + { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data back into registers + if (!MEMOIZE) { + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + } + ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial); + // Write data back to smem + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; } + } + + + //--------------------------------------------------------------------- + // Constructors + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ BlockReverseScan( + TempStorage &temp_storage) 
+ : + temp_storage(temp_storage.Alias()), + linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1)) + {} + + + /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs. + { + if (WARP_SYNCHRONOUS) { + // Short-circuit directly to warp-synchronous scan + T block_aggregate; + WarpReverseScan warp_scan; + warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate); + // Obtain warp-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output); + } else { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + detail::uninitialized_copy(placement_ptr, input); + cub::CTA_SYNC(); + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) { + WarpReverseScan warp_scan; + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + // Warp-synchronous scan + T exclusive_partial, block_aggregate; + warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); + // Obtain block-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + // Update postfix with warpscan exclusive partial + T downsweep_postfix = linear_tid == RAKING_THREADS - 1 + ? block_postfix : scan_op(block_postfix, exclusive_partial); + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, downsweep_postfix); + } + cub::CTA_SYNC(); + // Grab thread postfix from shared memory + exclusive_output = *placement_ptr; + + // // Compute warp scan in each warp. + // // The exclusive output from the last lane in each warp is invalid. + // T inclusive_output; + // WarpReverseScan warp_scan; + // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); + + // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. 
+ // T block_aggregate; + // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); + + // // Apply warp postfix to our lane's partial + // if (warp_id != 0) { + // exclusive_output = scan_op(warp_postfix, exclusive_output); + // if (lane_id == 0) { exclusive_output = warp_postfix; } + // } + + // // Use the first warp to determine the thread block postfix, returning the result in lane0 + // if (warp_id == 0) { + // T block_postfix = block_postfix_callback_op(block_aggregate); + // if (lane_id == 0) { + // // Share the postfix with all threads + // detail::uninitialized_copy(&temp_storage.block_postfix, + // block_postfix); + + // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 + // } + // } + + // cub::CTA_SYNC(); + + // // Incorporate thread block postfix into outputs + // T block_postfix = temp_storage.block_postfix; + // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } + } + } + + + /** + * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void InclusiveReverseScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. + { + // Reduce consecutive thread items in registers + T thread_postfix = ThreadReverseReduce(input, scan_op); + // Exclusive thread block-scan + ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); + // Inclusive scan in registers with postfix as seed + ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); + } + +}; \ No newline at end of file diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan.h b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan.h new file mode 100644 index 0000000000000000000000000000000000000000..11ad9e8db33b98978bba1d69259f23dd3a726dd1 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan.h @@ -0,0 +1,90 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMScanParamsBase { + using index_t = uint32_t; + + int batch, seqlen, n_chunks; + index_t a_batch_stride; + index_t b_batch_stride; + index_t out_batch_stride; + + // Common data pointers. 
+ void *__restrict__ a_ptr; + void *__restrict__ b_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + + // Common data pointers. + void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; +}; + +struct SSMParamsBwd: public SSMParamsBase { + index_t dout_batch_stride; + index_t dout_d_stride; + index_t dA_d_stride; + index_t dA_dstate_stride; + index_t dB_batch_stride; + index_t dB_group_stride; + index_t dB_d_stride; + index_t dB_dstate_stride; + index_t dC_batch_stride; + index_t dC_group_stride; + index_t dC_d_stride; + index_t dC_dstate_stride; + index_t du_batch_stride; + index_t du_d_stride; + index_t ddelta_batch_stride; + index_t ddelta_d_stride; + + // Common data pointers. + void *__restrict__ dout_ptr; + void *__restrict__ dA_ptr; + void *__restrict__ dB_ptr; + void *__restrict__ dC_ptr; + void *__restrict__ dD_ptr; + void *__restrict__ du_ptr; + void *__restrict__ ddelta_ptr; + void *__restrict__ ddelta_bias_ptr; +}; diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan_common.h b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan_common.h new file mode 100644 index 0000000000000000000000000000000000000000..828becf8703f1fe084720661e63b6c930e849a4c --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/selective_scan_common.h @@ -0,0 +1,210 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include // For scalar_value_type + +#define MAX_DSTATE 256 + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. 
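+    // cub::BlockScan calls this functor with the block-wide aggregate and uses its return value as
+    // the prefix for the current scan; by returning the old running prefix and folding the new
+    // aggregate into it, consecutive per-chunk scans compose into one scan over the full sequence.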
+ __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} + +template +inline __device__ void store_output1(typename Ktraits::output_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreOutputT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::output_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockStoreOutputVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + Ktraits::BlockStoreOutputT(smem_store).Store(out, write_vals, seqlen); + } +} + +template +inline __device__ void load_output(typename Ktraits::output_t *u, + typename Ktraits::output_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadOutputT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadOutputVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadOutputT(smem_load).Load(u, u_vals, seqlen, 
0.f); + } +} diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/static_switch.h b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..7920ac045d0a2a1f4c4159ee3eebe51fe1e2c203 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/uninitialized_copy.cuh b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/uninitialized_copy.cuh new file mode 100644 index 0000000000000000000000000000000000000000..630622dddcc9041737307810000584a843a01764 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/csrc/selective_scan/uninitialized_copy.cuh @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#pragma once + +#include + +#include + + +namespace detail +{ + +#if defined(_NVHPC_CUDA) +template +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + // NVBug 3384810 + new (ptr) T(::cuda::std::forward(val)); +} +#else +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + *ptr = ::cuda::std::forward(val); +} + +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + new (ptr) T(::cuda::std::forward(val)); +} +#endif + +} // namespace detail diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/PKG-INFO b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..0532dc90c2b9709aaec62c2773c5e2e68a8caf98 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/PKG-INFO @@ -0,0 +1,17 @@ +Metadata-Version: 2.1 +Name: selective-scan +Version: 0.0.2 +Summary: selective scan +Home-page: https://github.com/state-spaces/mamba +Author: Tri Dao, Albert Gu, $@#Anonymous#@$ +Author-email: tri@tridao.me, agu@cs.cmu.edu, $@#Anonymous#EMAIL@$ +License: UNKNOWN +Platform: UNKNOWN +Classifier: Programming Language :: Python :: 3 +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: Unix +Requires-Python: >=3.7 +Description-Content-Type: text/markdown + +UNKNOWN + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/SOURCES.txt b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..5e4f8e0f283c429344ee83a3e0a3f6c18f70add3 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/SOURCES.txt @@ -0,0 +1,10 @@ +README.md +setup.py +csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu +csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu +csrc/selective_scan/cusoflex/selective_scan_oflex.cpp +selective_scan.egg-info/PKG-INFO +selective_scan.egg-info/SOURCES.txt +selective_scan.egg-info/dependency_links.txt +selective_scan.egg-info/requires.txt +selective_scan.egg-info/top_level.txt \ No newline at end of file diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/dependency_links.txt b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/requires.txt b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..744bf2277cb49650bb81faa16ee0ebbfcf119dda --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/requires.txt @@ -0,0 +1,4 @@ +torch +packaging +ninja +einops diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/top_level.txt 
b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..97034d0b47249738a9c0f339065b51a434da8579 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/selective_scan.egg-info/top_level.txt @@ -0,0 +1 @@ +selective_scan_cuda_oflex diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/setup.py b/rscd/models/backbones/lib_mamba/kernels/selective_scan/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..5e117a7a2c9782625e9cb8b02d022aa14a619a83 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/setup.py @@ -0,0 +1,167 @@ +# Modified by $@#Anonymous#@$ #20240123 +# Copyright (c) 2023, Albert Gu, Tri Dao. +import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform +import shutil + +from setuptools import setup, find_packages +import subprocess +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, +) + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("FORCE_CXX11_ABI", "FALSE") == "TRUE" + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + +MODES = ["oflex"] +# MODES = ["core", "ndstate", "oflex"] +# MODES = ["core", "ndstate", "oflex", "nrow"] + +def get_ext(): + cc_flag = [] + + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + print("\n\nCUDA_HOME = {}\n\n".format(CUDA_HOME)) + + # Check, if CUDA11 is installed for compute capability 8.0 + multi_threads = True + gencode_sm90 = False + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + print("CUDA version: ", bare_metal_version, flush=True) + if bare_metal_version >= Version("11.8"): + gencode_sm90 = True + if bare_metal_version < Version("11.6"): + warnings.warn("CUDA version ealier than 11.6 may leads to performance mismatch.") + if bare_metal_version < Version("11.2"): + multi_threads = False + + cc_flag.extend(["-gencode", "arch=compute_70,code=sm_70"]) + cc_flag.extend(["-gencode", "arch=compute_80,code=sm_80"]) + if gencode_sm90: + cc_flag.extend(["-gencode", "arch=compute_90,code=sm_90"]) + if multi_threads: + cc_flag.extend(["--threads", "4"]) + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + sources = dict( + core=[ + "csrc/selective_scan/cus/selective_scan.cpp", + "csrc/selective_scan/cus/selective_scan_core_fwd.cu", + "csrc/selective_scan/cus/selective_scan_core_bwd.cu", + ], + nrow=[ + "csrc/selective_scan/cusnrow/selective_scan_nrow.cpp", + "csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu", + "csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu", + 
"csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu", + "csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu", + "csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu", + "csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu", + "csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu", + "csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu", + ], + ndstate=[ + "csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp", + "csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu", + "csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu", + ], + oflex=[ + "csrc/selective_scan/cusoflex/selective_scan_oflex.cpp", + "csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu", + "csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu", + ], + ) + + names = dict( + core="selective_scan_cuda_core", + nrow="selective_scan_cuda_nrow", + ndstate="selective_scan_cuda_ndstate", + oflex="selective_scan_cuda_oflex", + ) + + ext_modules = [ + CUDAExtension( + name=names.get(MODE, None), + sources=sources.get(MODE, None), + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + "nvcc": [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo", + ] + + cc_flag + }, + include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], + ) + for MODE in MODES + ] + + return ext_modules + +ext_modules = get_ext() +setup( + name="selective_scan", + version="0.0.2", + packages=[], + author="Tri Dao, Albert Gu, $@#Anonymous#@$ ", + author_email="tri@tridao.me, agu@cs.cmu.edu, $@#Anonymous#EMAIL@$", + description="selective scan", + long_description="", + long_description_content_type="text/markdown", + url="https://github.com/state-spaces/mamba", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + ext_modules=ext_modules, + cmdclass={"bdist_wheel": _bdist_wheel, "build_ext": BuildExtension} if ext_modules else {"bdist_wheel": _bdist_wheel,}, + python_requires=">=3.7", + install_requires=[ + "torch", + "packaging", + "ninja", + "einops", + ], +) diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/test_selective_scan.py b/rscd/models/backbones/lib_mamba/kernels/selective_scan/test_selective_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..862e8eef28123b9dced64e379091c8196738ba3f --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/test_selective_scan.py @@ -0,0 +1,505 @@ +# Modified by $@#Anonymous#@$ #20240123 +# Copyright (C) 2023, Tri Dao, Albert Gu. 
+ +import math +import torch +import torch.nn.functional as F +import pytest +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd +from einops import rearrange, repeat +import time +from functools import partial + +SSOFLEX_FLOAT = True + + +def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm", tag=None): + MODE = mode + + class SelectiveScanFn(torch.autograd.Function): + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + if D is not None and (D.dtype != torch.float): + ctx._d_dtype = D.dtype + D = D.float() + if delta_bias is not None and (delta_bias.dtype != torch.float): + ctx._delta_bias_dtype = delta_bias.dtype + delta_bias = delta_bias.float() + + assert u.shape[1] % (B.shape[1] * nrows) == 0 + assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile + + if backnrows > 0: + assert u.shape[1] % (B.shape[1] * backnrows) == 0 + assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile + else: + backnrows = nrows + ctx.backnrows = backnrows + + if MODE in ["mamba_ssm"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + elif MODE in ["ssoflex"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, SSOFLEX_FLOAT) + elif MODE in ["sscore"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) + elif MODE in ["sstest"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) + elif MODE in ["sscorendstate"]: + assert A.shape[-1] == 1 and B.shape[2] == 1 and C.shape[2] == 1 + A = A.view(-1) + B = B.squeeze(2) + C = C.squeeze(2) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1) + else: + raise NotImplementedError + + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + if MODE in ["mamba_ssm", "sstest"]: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + elif MODE in ["sscore", "ssoflex"]: + return out if not return_last_state else (out, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. 
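+            # The bwd signature differs per build: the mamba_ssm/sstest kernels also take z and out
+            # (with an unused flag to recompute out_z), while the sscore/ssoflex/sscorendstate kernels
+            # drop z and only need the cached chunk states x.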
+ if MODE in ["mamba_ssm"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + elif MODE in ["sstest"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False, ctx.backnrows # option to recompute out_z, not used here + ) + elif MODE in ["sscore", "ssoflex"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows + ) + elif MODE in ["sscorendstate"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 + ) + dA = dA.unsqueeze(1) + dB = dB.unsqueeze(2) + dC = dC.unsqueeze(2) + else: + raise NotImplementedError + + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + + _dD = None + if D is not None: + if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype): + _dD = dD.to(ctx._d_dtype) + else: + _dD = dD + + _ddelta_bias = None + if delta_bias is not None: + if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype): + _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype) + else: + _ddelta_bias = ddelta_bias + + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, None, None, None) + + def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. + """ + outs = SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows) + if mode in ["ssoflex"]: + return outs.to(u.dtype) if not return_last_state else (outs[0].to(u.dtype), outs[1]) + else: + return outs + + selective_scan_fn.__repr__ = lambda *_ :f"selective_scan_fn | {mode} | {tag}" + + return selective_scan_fn + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +def selective_scan_ref_v2(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + A = A.to(dtype_in) + B = B.to(dtype_in) + C = C.to(dtype_in) + D = D.to(dtype_in) if D is not None else None + z = z.to(dtype_in) if z is not None else None + delta = delta.to(dtype_in) if delta is not None else None + delta_bias = delta_bias.to(dtype_in) if delta_bias is not None else None + + if delta_bias is not None: + delta = delta + delta_bias[..., None] + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B, "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C, "... (L two) -> ... 
L two", two=2)) + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state.float()) + + +def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, *args, **kwargs): + return selective_scan_ref_v2(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + +# MODE = None +# MODE = "mamba_ssm" +# MODE = "sscore" +# MODE = "ssoflex" +# MODE = "sstest" +# MODE = "mamba_ssm_sscore" # 1344 items pass +# MODE = "mamba_ssm_sscorendstate" # 1344 items pass +MODE = "mamba_ssm_ssoflex" # 1344 items pass + +if MODE in ["mamba_ssm"]: + import selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["ssoflex"]: + import selective_scan_cuda_oflex + selective_scan_cuda = selective_scan_cuda_oflex + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_oflex, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["sscore"]: + import selective_scan_cuda_core + selective_scan_cuda = selective_scan_cuda_core + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["sstest"]: + import selective_scan_cuda_test + selective_scan_cuda = selective_scan_cuda_test + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["mamba_ssm_sscore"]: + import selective_scan_cuda_core + import selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode="sscore") + selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") +elif MODE in ["mamba_ssm_sstest"]: + import selective_scan_cuda_test + import selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode="sstest") + selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") +elif MODE in ["mamba_ssm_sscorendstate"]: + import selective_scan_cuda_core + import selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode="sscorendstate") + selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") +elif MODE in ["mamba_ssm_ssoflex"]: + import selective_scan_cuda_oflex + import selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_oflex, 
mode="ssoflex") + selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") +else: + selective_scan_cuda = None + + +print("use MODE:", MODE) +DSTATE = [1] +DIM = [768] +BATCHSIZE = [2] +# DSTATE = [1] if MODE in ["mamba_ssm_sscorendstate", "sscorendstate"] else [8] +NROWS = [1,2,3,4] +IDTYPE = MODE in [None] + +# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('seqlen', [64, 128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [False, True]) +@pytest.mark.parametrize('delta_softplus', [False, True]) +# @pytest.mark.parametrize('has_z', [False, True]) +@pytest.mark.parametrize('has_z', [False]) +@pytest.mark.parametrize('has_D', [False, True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +# @pytest.mark.parametrize("is_variable_C", [False, True]) +@pytest.mark.parametrize("is_variable_C", [True]) +# @pytest.mark.parametrize("is_variable_B", [False, True]) +@pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("nrows", NROWS) +@pytest.mark.parametrize("batch_size", BATCHSIZE) +@pytest.mark.parametrize("dim", DIM) +@pytest.mark.parametrize("dstate", DSTATE) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, itype, wtype, nrows, batch_size, dim, dstate): + wtype = itype if IDTYPE else wtype + print(f'method: {selective_scan_cuda}') + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + # batch_size = 2 + # dim = 24 + # dstate = 8 + is_complex = wtype == torch.complex64 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, 
dtype=itype)).requires_grad_() + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + out, *rest = selective_scan_fn( + u, delta, A, B, C, D, z=z, + delta_bias=delta_bias, delta_softplus=delta_softplus, + return_last_state=return_last_state, nrows=nrows + ) + if return_last_state: + state = rest[0] + out_ref, *rest = selective_scan_ref( + u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, + delta_bias=delta_bias_ref, delta_softplus=delta_softplus, + return_last_state=return_last_state + ) + if return_last_state: + state_ref = rest[0] + # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + # dt_u = delta * u + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + print(f'State max diff: {(state - state_ref).abs().max().item()}') + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') + print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') + print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') + print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') + print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') + if has_D: + print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') + if has_z: + print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') + if has_delta_bias: + print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') + + assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) + assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) + assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) + assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, + atol=atolw if not is_variable_B else atol) + assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, + atol=atolw if not is_variable_C else atol) + if has_D: + assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) + if has_z: + assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) + if has_delta_bias: + assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + +# test_selective_scan(True, True, 2, True, False, True, True, True, 64, torch.float32, torch.float32, 1, 2, 24, 1) + diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/test_selective_scan_easy.py b/rscd/models/backbones/lib_mamba/kernels/selective_scan/test_selective_scan_easy.py new file mode 100644 index 0000000000000000000000000000000000000000..9247d66a66a2c116b71af38666eeae3844fbfe64 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/test_selective_scan_easy.py @@ -0,0 +1,1064 @@ +# Modified by $@#Anonymous#@$ #20240123 +# Copyright (C) 2023, Tri Dao, Albert Gu. 
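+# --------------------------------------------------------------------------
+# Note (added, hedged): selective_scan_easy below evaluates, chunk by chunk,
+# the same recurrence that selective_scan_ref implements step by step:
+#     h_t = exp(delta_t * A) * h_{t-1} + (delta_t * B_t) * u_t
+#     y_t = <C_t, h_t> + D * u_t
+# The self-contained sketch below spells that recurrence out on tiny random
+# tensors using the documented shapes (u/delta: (B, D, L), A: (D, N),
+# variable B/C: (B, N, L), D: (D,)); the sizes and names are illustrative
+# only and are not used by the tests in this file.
+def _reference_scan_demo():
+    import torch
+    B_, D_, N_, L_ = 2, 4, 8, 16                              # tiny illustrative sizes
+    u = torch.randn(B_, D_, L_)
+    delta = 0.5 * torch.rand(B_, D_, L_)
+    A = -0.5 * torch.rand(D_, N_)
+    Bv = torch.randn(B_, N_, L_)                              # "variable B", layout (B N L)
+    Cv = torch.randn(B_, N_, L_)                              # "variable C", layout (B N L)
+    Dp = torch.randn(D_)
+    h = torch.zeros(B_, D_, N_)
+    ys = []
+    for t in range(L_):
+        dA = torch.exp(delta[:, :, t, None] * A)              # discretized A, (B, D, N)
+        dBu = delta[:, :, t, None] * Bv[:, None, :, t] * u[:, :, t, None]
+        h = dA * h + dBu                                       # state update
+        ys.append(torch.einsum("bdn,bn->bd", h, Cv[:, :, t]))  # readout through C_t
+    return torch.stack(ys, dim=2) + u * Dp[None, :, None]      # add the skip term D*u
+# --------------------------------------------------------------------------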
+import math +from functools import partial +import torch +import torch.nn.functional as F +import pytest +from einops import rearrange, repeat + +MODE = "v0" +# MODE = "fn" +# MODE = "fnDEBUG" + +def selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64): + """ + # B: batch_size, G: groups, D: dim, N: state dim, L: seqlen + us: B, G * D, L + dts: B, G * D, L + As: G * D, N + Bs: B, G, N, L + Cs: B, G, N, L + Ds: G * D + delta_bias: G * D + # chunksize can be any as you like. But as the chunksize raises, hs may get None, as exp(sum(delta) A) is really small + """ + def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix): + """ + partial(h) / partial(t) = Ah + Bu; y = Ch + Du; + => partial(h*exp(-At)) / partial(t) = Bu*exp(-At); + => h_t = h_0 + sum_{0}_{t}_{Bu*exp(A(t-v)) dv}; + => h_b = exp(A(dt_a + ... + dt_{b-1})) * (h_a + sum_{a}_{b-1}_{Bu*exp(-A(dt_a + ... + dt_i)) dt_i}); + y_i = C_i*h_i + D*u_i + """ + """ + us, dts: (L, B, G, D) # L is chunk_size + As: (G, D, N) + Bs, Cs: (L, B, G, N) + Ds: (G, D) + hprefix: (B, G, D, N) + """ + ts = dts.cumsum(dim=0) + Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp() + # scale = Ats[-1].detach() + scale = 1 + rAts = Ats / scale + duts = dts * us + dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs) + hs_tmp = rAts * (dtBus / rAts).cumsum(dim=0) + hs = hs_tmp + Ats * hprefix.unsqueeze(0) + ys = torch.einsum("lbgn,lbgdn->lbgd", Cs, hs) + return ys, hs + + + dtype = torch.float32 + # dtype = torch.float16 + inp_dtype = us.dtype + has_D = Ds is not None + if chunksize < 1: + chunksize = Bs.shape[-1] + + dts = dts.to(dtype) + if delta_bias is not None: + dts = dts + delta_bias.view(1, -1, 1).to(dtype) + if delta_softplus: + dts = torch.nn.functional.softplus(dts) + + if len(Bs.shape) == 3: + Bs = Bs.unsqueeze(1) + if len(Cs.shape) == 3: + Cs = Cs.unsqueeze(1) + B, G, N, L = Bs.shape + us = us.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype) + dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype) + As = As.view(G, -1, N).to(dtype) + Bs = Bs.permute(3, 0, 1, 2).to(dtype) + Cs = Cs.permute(3, 0, 1, 2).to(dtype) + Ds = Ds.view(G, -1).to(dtype) if has_D else None + D = As.shape[1] + + oys = [] + hprefix = us.new_zeros((B, G, D, N), dtype=dtype) + for i in range(0, L, chunksize): + ys, hs = selective_scan_chunk( + us[i:i + chunksize], dts[i:i + chunksize], + As, Bs[i:i + chunksize], Cs[i:i + chunksize], hprefix, + ) + oys.append(ys) + hprefix = hs[-1] + + oys = torch.cat(oys, dim=0) + if has_D: + oys = oys + Ds * us + oys = oys.permute(1, 2, 3, 0).view(B, -1, L) + + # return oys, hprefix.view(B, G * D, N) + return oys.to(inp_dtype) if not return_last_state else (oys.to(inp_dtype), hprefix.view(B, G * D, N).float()) + + +class SelectiveScanEasy(torch.autograd.Function): + # for debug, we use it as an orinary object + DEBUG = (MODE == "fnDEBUG") + + if DEBUG: + print("DEBUG here...", flush=True) + saved_tensors = [] + + @classmethod + def save_for_backward(ctx, *args): + ctx.saved_tensors = args + + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(ctx, us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64): + has_D = Ds is not None + dtype = torch.float32 + + dts = dts.to(dtype) + if delta_bias is not None: + dts = dts + delta_bias.view(1, -1, 1).to(dtype) + if delta_softplus: + dts = torch.nn.functional.softplus(dts) + + B_squeeze = (len(Bs.shape) == 3) + C_squeeze = (len(Cs.shape) == 3) + if B_squeeze: + 
Bs = Bs.unsqueeze(1) + if C_squeeze: + Cs = Cs.unsqueeze(1) + B, G, N, L = Bs.shape + us = us.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype) + dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype) + As = As.view(G, -1, N).to(dtype) + Bs = Bs.permute(3, 0, 1, 2).to(dtype) + Cs = Cs.permute(3, 0, 1, 2).to(dtype) + Ds = Ds.view(G, -1).to(dtype) if has_D else None + D = As.shape[1] + + ctx.shape = (B, G, D, N, L) + ctx.delta_softplus = delta_softplus + ctx.return_last_state = return_last_state + ctx.chunksize = chunksize + ctx.BC_squeeze = (B_squeeze, C_squeeze) + save_for_backward = [us, dts, As, Bs, Cs, Ds, delta_bias] + + chunks = list(range(0, L, chunksize)) + oys = [] + ohs = [] + hprefix = us.new_zeros((B, G, D, N), dtype=torch.float) + for i in chunks: + ts = dts[i:i+chunksize].cumsum(dim=0) + Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp() + # scale = Ats[-1:].detach() + scale = 1 + rAts = Ats / scale + duts = dts[i:i + chunksize] * us[i:i + chunksize] + dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs[i:i + chunksize]) + tmp_dtBus_div_rAts = (dtBus / rAts) + tmp_dtBus_div_rAts_cumsum = tmp_dtBus_div_rAts.cumsum(dim=0) + hs = rAts * tmp_dtBus_div_rAts_cumsum + Ats * hprefix.unsqueeze(0) + ys = torch.einsum("lbgn,lbgdn->lbgd", Cs[i:i + chunksize], hs) + oys.append(ys) + ohs.append(hs) + hprefix = hs[-1] + + oys = torch.cat(oys, dim=0) + ohs = torch.cat(ohs, dim=0) + if has_D: + oys = oys + Ds * us + + save_for_backward.extend([ohs]) + + ctx.save_for_backward(*save_for_backward) + + oys = oys.permute(1, 2, 3, 0).view(B, -1, L) + + if getattr(ctx, "DEBUG", False): + print("DEBUG here ..............", flush=True) + oys.backward = partial(ctx.backward, ctx) + + return oys, hprefix.view(B, G * D, N) + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, doys: torch.Tensor, *args): + DEBUG = getattr(ctx, "DEBUG", False) + us, dts, As, Bs, Cs, Ds, delta_bias, ohs = ctx.saved_tensors + + B, G, D, N, L = ctx.shape + chunksize = ctx.chunksize + delta_softplus = ctx.delta_softplus + doys = doys.view(B, G, D, L).permute(3, 0, 1, 2) + + def rev_comsum(x): + cum_sum = torch.cumsum(x, dim=0) + return (x - cum_sum + cum_sum[-1:None]) + + if DEBUG: + dtype = torch.float32 + us = us.requires_grad_() + dts = dts.requires_grad_() + As = As.requires_grad_() + Bs = Bs.requires_grad_() + Cs = Cs.requires_grad_() + Ds = Ds.requires_grad_() if Ds is not None else None + delta_bias = delta_bias.requires_grad_() if delta_bias is not None else None + ohs = ohs.requires_grad_() + + # copy forward again + if DEBUG: + has_D = Ds is not None + + tmp_fwd_dtBus = [] + tmp_fwd_rAts = [] + tmp_fwd_Ats = [] + tmp_fwd_dtBus_div_rAts_cumsum = [] + tmp_fwd_dtBus_div_rAts = [] + + chunks = list(range(0, L, chunksize)) + oys = [] + ohs = [] + hprefix = us.new_zeros((B, G, D, N), dtype=torch.float) + for i in chunks: + ts = dts[i:i+chunksize].cumsum(dim=0) + Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp() + # scale = Ats[-1:].detach() + scale = 1 + rAts = Ats / scale + duts = dts[i:i + chunksize] * us[i:i + chunksize] + dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs[i:i + chunksize]) + tmp_dtBus_div_rAts = (dtBus / rAts) + tmp_dtBus_div_rAts_cumsum = tmp_dtBus_div_rAts.cumsum(dim=0) + hs = rAts * tmp_dtBus_div_rAts_cumsum + Ats * hprefix.unsqueeze(0) + ys = torch.einsum("lbgn,lbgdn->lbgd", Cs[i:i + chunksize], hs) + oys.append(ys) + ohs.append(hs) + hprefix = hs[-1] + + tmp_fwd_dtBus_div_rAts_cumsum.append(tmp_dtBus_div_rAts_cumsum) + tmp_fwd_dtBus_div_rAts.append(tmp_dtBus_div_rAts) + 
tmp_fwd_dtBus.append(dtBus) + tmp_fwd_rAts.append(rAts) + tmp_fwd_Ats.append(Ats) + + oys = torch.cat(oys, dim=0) + ohs = torch.cat(ohs, dim=0) + if has_D: + oys = oys + Ds * us + + if DEBUG: + _oys = oys.requires_grad_() + + dus = None + dDs = None + if Ds is not None: + dDs = torch.einsum("lbgd,lbgd->gd", doys, us).view(-1) + dus = torch.einsum("lbgd,gd->lbgd", doys, Ds) + + chunks = list(range(0, L, chunksize)) + dAs = us.new_zeros((G, D, N), dtype=torch.float) + dus = us.new_zeros((L, B, G, D), dtype=torch.float) if dus is None else dus + ddts = us.new_zeros((L, B, G, D), dtype=torch.float) + dBs = us.new_zeros((L, B, G, N), dtype=torch.float) + dCs = us.new_zeros((L, B, G, N), dtype=torch.float) + dhprefix = us.new_zeros((B, G, D, N), dtype=torch.float) + for i in chunks[::-1]: + # forward procedure ================ + tmp_dts = dts[i:i+chunksize] + ts = tmp_dts.cumsum(dim=0) + Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp() + scale = Ats[-1].detach() + scale = 1 + rAts = Ats / scale + duts = dts[i:i + chunksize] * us[i:i + chunksize] + dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs[i:i + chunksize]) + dtBus_div_rAts = (dtBus / rAts) + hs_minus_prefix_div_rAts = dtBus_div_rAts.cumsum(dim=0) + + # hs_minus_prefix = rAts * hs_minus_prefix_div_rAts + + # below code is not ok due to precision limitation... + # use saved hs instead + if False: + hprefix = (hsuffix - hs_minus_prefix[-1]) / scale + hs = hs_minus_prefix + Ats * hprefix.unsqueeze(0) + + # backward procedure ================ + _hs = ohs[i:i+chunksize] + _hprefix = ohs[i - 1] if i > 0 else None + dCs[i:i + chunksize] = torch.einsum("lbgd,lbgdn->lbgn", doys[i:i + chunksize], _hs) + dhs = doys[i:i + chunksize].unsqueeze(4) * Cs[i:i + chunksize].unsqueeze(3) # lbgd,lbgn->lbgdn + dhs[-1] = dhs[-1] + dhprefix + dhprefix = torch.einsum("lbgdn,lbgdn -> bgdn", dhs, Ats) + dAts_hprefix = dhs * _hprefix.unsqueeze_(0) if i > 0 else None # lbgdn,bgdn->lbgdn + drAts_hs_minus_prefix = dhs * hs_minus_prefix_div_rAts + dhs_minus_prefix_div_rAts = dhs * rAts + + if DEBUG: + print("1", (torch.autograd.grad(_oys, tmp_fwd_dtBus_div_rAts_cumsum[chunks.index(i)], doys, create_graph=True, allow_unused=True)[0] - dhs_minus_prefix_div_rAts).abs().sum()) + + d_dtBus_div_rAts = rev_comsum(dhs_minus_prefix_div_rAts) + if DEBUG: + d_dtBus_div_rAts_v1 = torch.autograd.grad(hs_minus_prefix_div_rAts, dtBus_div_rAts, dhs_minus_prefix_div_rAts, create_graph=True, allow_unused=True)[0] + print("2.0", (torch.autograd.grad(_oys, tmp_fwd_dtBus_div_rAts[chunks.index(i)], doys, create_graph=True, allow_unused=True)[0] - d_dtBus_div_rAts).abs().sum()) + print("2.1", (d_dtBus_div_rAts - d_dtBus_div_rAts_v1).abs().sum()) + d_dtBus_div_rAts = d_dtBus_div_rAts_v1 + + ddtBus = d_dtBus_div_rAts / rAts + dBs[i:i + chunksize] = torch.einsum("lbgdn,lbgd->lbgn", ddtBus, duts) + dduts = torch.einsum("lbgdn,lbgn->lbgd", ddtBus, Bs[i:i + chunksize]) + dus[i:i + chunksize] = dus[i:i + chunksize] + dduts * dts[i:i + chunksize] + if DEBUG: + print("3", (torch.autograd.grad(_oys, tmp_fwd_dtBus[chunks.index(i)], doys, create_graph=True, allow_unused=True)[0] - ddtBus).abs().sum()) + + if DEBUG: + tmp_a = torch.randn((L, B, G, D, N)).to(dtype).cuda().requires_grad_() + tmp_b = torch.cumsum(tmp_a, dim=0) + tmp_c = torch.randn((L, B, G, D, N)).to(dtype).cuda() + print("ex.0", (torch.autograd.grad(tmp_b, tmp_a, tmp_c, create_graph=True, allow_unused=True)[0] - rev_comsum(tmp_c)).abs().sum()) + + drAts_dtBus_div_rAts = d_dtBus_div_rAts * (-dtBus_div_rAts / rAts) + if DEBUG: + 
drAts_dtBus_div_rAts_v1 = d_dtBus_div_rAts * (dtBus / -(rAts * rAts)) # do not use this!!! + drAts_dtBus_div_rAts_ref = torch.autograd.grad(dtBus_div_rAts, rAts, d_dtBus_div_rAts, create_graph=True, allow_unused=True)[0] + print("4.0", (drAts_dtBus_div_rAts - drAts_dtBus_div_rAts_ref).abs().sum()) + print("4.0_v1", (drAts_dtBus_div_rAts - drAts_dtBus_div_rAts_v1).abs().sum()) + + ddts[i:i + chunksize] = dduts * us[i:i + chunksize] + dAts = drAts_dtBus_div_rAts / scale + drAts_hs_minus_prefix / scale + (dAts_hprefix if i > 0 else 0) + + if DEBUG: + drAts_ref = torch.autograd.grad(_oys, tmp_fwd_rAts[chunks.index(i)], doys, create_graph=True, allow_unused=True)[0] + dAts_ref = torch.autograd.grad(_oys, tmp_fwd_Ats[chunks.index(i)], doys, create_graph=True, allow_unused=True)[0] + print("4.1", (drAts_ref - (drAts_dtBus_div_rAts + drAts_hs_minus_prefix)).abs().sum()) + print("4.2", ((drAts_ref - (drAts_dtBus_div_rAts + drAts_hs_minus_prefix)) / scale).abs().sum()) + print("4.3", (dAts_ref - dAts).abs().sum()) + + dAts_noexp = dAts * Ats # d(e^x) = e^x * dx + dAs = dAs + torch.einsum("lbgdn,lbgd->gdn", dAts_noexp, ts) + d_ts = torch.einsum("lbgdn,gdn->lbgd", dAts_noexp, As) + + _part_ddts = rev_comsum(d_ts) # the precision is enough + if DEBUG: + _part_ddts_v1 = torch.autograd.grad(ts, tmp_dts, d_ts, create_graph=True, allow_unused=True)[0] + print("5.0", (_part_ddts - _part_ddts_v1).abs().sum()) + + ddts[i:i + chunksize] = ddts[i:i + chunksize] + _part_ddts + + if DEBUG: + print("f", (torch.autograd.grad(_oys, dts, doys, create_graph=True, allow_unused=True)[0] - ddts).abs().sum(), flush=True) + + if delta_softplus: + # softplus = log(1 + e^x); dsoftplus = e^x /(1+e^-x) = 1 - 1 / (1+e^x) = 1 - (e^-softplus) + ddts = ddts - ddts * (-dts).exp() + + ddelta_bias = None + if delta_bias is not None: + ddelta_bias = ddts.sum([0, 1]).view(-1) + + if DEBUG: + print("f", (torch.autograd.grad(_oys, us, doys, create_graph=True, allow_unused=True)[0] - dus).abs().sum(), flush=True) + print("f", (torch.autograd.grad(_oys, Bs, doys, create_graph=True, allow_unused=True)[0] - dBs).abs().sum(), flush=True) + print("f", (torch.autograd.grad(_oys, Cs, doys, create_graph=True, allow_unused=True)[0] - dCs).abs().sum(), flush=True) + print("f", (torch.autograd.grad(_oys, Ds, doys, create_graph=True, allow_unused=True)[0].view(-1) - dDs).abs().sum(), flush=True) + print("f", (torch.autograd.grad(_oys, As, doys, create_graph=True, allow_unused=True)[0] - dAs).abs().sum(), flush=True) + # print("f", (torch.autograd.grad(_oys, delta_bias, doys, create_graph=True, allow_unused=True)[0] - ddelta_bias).abs().sum(), flush=True) + + dus = dus.permute(1, 2, 3, 0).view(B, -1, L) + ddts = ddts.permute(1, 2, 3, 0).view(B, -1, L) + dAs = dAs.view(-1, N) + dBs = dBs.permute(1, 2, 3, 0) + dCs = dCs.permute(1, 2, 3, 0) + if ctx.BC_squeeze[0]: + dBs = dBs.flatten(1, 2) + if ctx.BC_squeeze[1]: + dCs = dCs.flatten(1, 2) + + return dus, ddts, dAs, dBs, dCs, dDs, ddelta_bias, None, None, None + + +def selective_scan_easy_fwdbwd(u, delta, A, B, C, D, delta_bias=None, delta_softplus=None, + return_last_state=False, chunksize=64): + mode = MODE + if mode in ["fnDEBUG"]: + outs = SelectiveScanEasy.forward(SelectiveScanEasy, u, delta, A, B, C, D, delta_bias, delta_softplus, return_last_state, chunksize) + return (outs[0].to(u.dtype), *outs[1:]) if return_last_state else outs[0].to(u.dtype) + else: + outs = SelectiveScanEasy.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, return_last_state, chunksize) + return (outs[0].to(u.dtype), 
*outs[1:]) if return_last_state else outs[0].to(u.dtype) + + +def selective_scan_easyv2(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64): + if chunksize < 1: + chunksize = Bs.shape[-1] + mask = torch.tril(torch.ones((chunksize, chunksize), device=us.device), diagonal=0) + + def ss_chunk(us, dts, As, Bs, Cs, h0, mask): + # BHLD, BHLN, HND, BHDN + cL = us.shape[2] + _mask = (mask[:cL,:cL].contiguous() if cL < chunksize else mask).view(1, 1, cL, cL, 1) + + w_log = As[None, :, None, :, :] * (torch.cumsum(dts, dim=2)[..., None, :]) # (B, H, L, Dk, Dv) + v = us * dts # (B,H,L,Dv) + k = Bs # (B,H,L,Dk) + q = Cs # (B,H,L,Dk) + w = w_log.exp() + + k_div_w = k[..., None] / w + q_mul_w = q[..., None] * w + + # h0 independent ==================== + next_h_1 = w[:,:,-1] * torch.einsum("bhlkv,bhlv->bhkv", (k_div_w), v) + y_1 = torch.einsum("bhlrv,bhrv->bhlv", torch.einsum("bhlkv,bhrkv->bhlrv", q_mul_w, k_div_w) * _mask, v) + + # h0 dependent ====================== + y_0 = torch.einsum("bhlkv,bhkv->bhlv", q_mul_w, h0) + next_h_0 = w[:,:, -1] * h0 + + next_h = next_h_0 + next_h_1 + y = y_0 + y_1 + + return y, next_h + + dtype = torch.float32 + # dtype = torch.float16 + inp_dtype = us.dtype + has_D = Ds is not None + dts = dts.to(dtype) + + if delta_bias is not None: + dts = dts + delta_bias.view(1, -1, 1).to(dtype) + if delta_softplus: + dts = torch.nn.functional.softplus(dts) + + if len(Bs.shape) == 3: + Bs = Bs.unsqueeze(1) + if len(Cs.shape) == 3: + Cs = Cs.unsqueeze(1) + + B, GD, L = us.shape + B, G, N, L = Bs.shape + D = GD // G + us = us.view(B, G, -1, L).permute(0, 1, 3, 2).to(dtype) + dts = dts.view(B, G, -1, L).permute(0, 1, 3, 2).to(dtype) + As = As.view(G, D, N).permute(0, 2, 1).to(dtype) + Bs = Bs.permute(0, 1, 3, 2).to(dtype) + Cs = Cs.permute(0, 1, 3, 2).to(dtype) + Ds = Ds.view(G, -1).to(dtype) if has_D else None + + oys = [] + hprefix = us.new_zeros((B, G, N, D), dtype=dtype) + for i in range(0, L, chunksize): + ys, hprefix = ss_chunk( + us[:,:, i:i + chunksize], dts[:,:, i:i + chunksize], + As, Bs[:,:, i:i + chunksize], Cs[:,:, i:i + chunksize], hprefix, mask, + ) + oys.append(ys) + + oys = torch.cat(oys, dim=2) + if has_D: + oys = oys + Ds.view(1, G, 1, D) * us + oys = oys.permute(0, 1, 3, 2).contiguous().view(B, -1, L) + hprefix = hprefix.permute(0, 1, 3, 2).contiguous().view(B, GD, N).float() + + return oys.to(inp_dtype) if not return_last_state else (oys.to(inp_dtype), hprefix) + + +def selective_scan_easyv3(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64): + inv_ln2 = 1.44269504 + if chunksize < 0: + chunksize = 64 + chunksize = min(chunksize, Bs.shape[-1]) + # chunksize = Bs.shape[-1] + if len(Bs.shape) == 3: + Bs = Bs.unsqueeze(1) + if len(Cs.shape) == 3: + Cs = Cs.unsqueeze(1) + + B, GD, L = us.shape + B, G, N, L = Bs.shape + D = GD // G + + # mask triu ============== + _arange = torch.arange(0, chunksize, dtype=torch.int8, device=Bs.device) + _row_arange = _arange[None, :] # ((0, 1), (0, 1)) + _col_arange = _arange[:, None] # ((0, 0), (1, 1)) + # _mask_triu = tl.where(_row_arange >= _col_arange, 1, 0) + # _mask_tril = tl.where(_row_arange <= _col_arange, 1, 0) + _mask_tril = (_col_arange >= _row_arange).float() + + def cut_chunk(us, dts, Bs, Cs, chunksize=chunksize): + B, H, L, D = us.shape + B, H, L, N = Bs.shape + NT = -(L // -chunksize) + to_pad = NT * chunksize - L + _pad = lambda x: torch.nn.functional.pad(x.view(B * H, L, -1), (0,0,0,to_pad,0,0)).view(B * H, NT, 
chunksize, x.shape[-1]) + us, dts, Bs, Cs = _pad(us), _pad(dts), _pad(Bs), _pad(Cs) + return us, dts, Bs, Cs + + def ss_chunk_h1y1(qs, ks, vs, ws=None, As=None, ts=None, dts=None, mask=None, scale=1): + # C = n_chunks, M = B * H, E = B * H * C, T = L / C + # MCTN, MCTN, MCTD; MCTND; HND, MCTD, MCTD; + if ws is None: + if ts is None: + ts = torch.cumsum(dts * inv_ln2, dim=2) + _ts = ts.view(-1, As.shape[0], *ts.shape[1:])[:, :, :, :, None, :] + ws = torch.exp2(As[None, :, None, None, :, :] * _ts).flatten(0, 1) # MCND + q_mul_w = qs[...,None] * ws * scale + k_div_w = ks[...,None] / ws + qwkw = torch.einsum("mctnd,mcrnd->mctrd", q_mul_w, k_div_w) + qwkw = qwkw * mask[None, None, :, :, None] + y1 = torch.einsum("mctrd,mcrd->mctd", qwkw, vs) + ht1 = ws[:,:,-1,:,:] * (k_div_w * vs[...,None,:]).sum(dim=-3) + cws = ws[:,:,-1,:,:] + return ht1, y1, ws, cws, q_mul_w # MCND, MCTD, MCTND, MCND, MCTND + + def ss_chunk_h(cws, ht1): + device, dtype = ht1.device, ht1.dtype + M, C, N, D = ht1.shape + hts = [torch.zeros((M, N, D), device=device, dtype=dtype)] + inith = hts[0] + for c in range(C): + inith = cws[:, c] * inith + ht1[:, c] + hts.append(inith) + return torch.stack(hts, dim=1) # M(C+1)ND + + def ss_chunk_y(y1, hs, q_mul_w): + iniths = hs[:,:-1,:,:].contiguous() + y0 = torch.einsum("mctnd,mcnd->mctd", q_mul_w, iniths) + y = y0 + y1 + return y + + def ss_chunk_h1y1_dk1(qs, ks, vs, ws=None, As=None, ts=None, dts=None, mask=None, scale=1): + # C = n_chunks, M = B * H, E = B * H * C, T = L / C + # MCTN, MCTN, MCTD; MCTND; HND, MCTD, MCTD; + M, C, T, N = qs.shape + assert N == 1 + if ws is None: + if ts is None: + ts = torch.cumsum(dts, dim=2) + _ts = ts.view(-1, As.shape[0], *ts.shape[1:])[:, :, :, :, None, :] + ws = (As[None, :, None, None, :, :] * _ts).exp().flatten(0, 1) # MCND + q_mul_w = qs[...,None] * ws * scale + # k_div_w = ks[...,None] / ws + v_div_w = vs / ws[:, :, :, 0, :] # MCTD + + y1 = ws[:,:,:,0,:] * torch.einsum("mctr,mcrd->mctd", qs[:,:,:,None,0] * ks[:,:,None,:,0] * mask[None, None, :, :], v_div_w) + ht1 = (ws[:,:,-1,0,:] * (ks * v_div_w).sum(dim=-2))[:, :, None, :] + cws = ws[:,:,-1,:,:] + return ht1, y1, ws, cws, q_mul_w # MCND, MCTD, MCTND, MCND, MCTND + + def ss_chunk_y_dk1(y1, hs, q_mul_w): + iniths = hs[:,:-1,:,:].contiguous() + y0 = q_mul_w[:, :, :, 0, :] * iniths + y = y0 + y1 + return y + + if N == 1: + ss_chunk_h1y1 = ss_chunk_h1y1_dk1 + ss_chunk_y = ss_chunk_y_dk1 + + dtype = torch.float32 + # dtype = torch.float16 + inp_dtype = us.dtype + has_D = Ds is not None + dts = dts.to(dtype) + + if delta_bias is not None: + dts = dts + delta_bias.view(1, -1, 1).to(dtype) + if delta_softplus: + dts = torch.nn.functional.softplus(dts) + + us = us.view(B, G, -1, L).permute(0, 1, 3, 2).to(dtype) + dts = dts.view(B, G, -1, L).permute(0, 1, 3, 2).to(dtype) + As = As.view(G, D, N).permute(0, 2, 1).to(dtype) + Bs = Bs.permute(0, 1, 3, 2).to(dtype) + Cs = Cs.permute(0, 1, 3, 2).to(dtype) + Ds = Ds.view(G, -1).to(dtype) if has_D else None + + _us, dts, Bs, Cs = cut_chunk(us, dts, Bs, Cs, chunksize=chunksize) + ht1, y1, ws, cws, q_mul_w = ss_chunk_h1y1(Cs, Bs, _us * dts, None, As, None, dts, mask=_mask_tril) + hts = ss_chunk_h(cws, ht1) # M(C+1)ND + oys = ss_chunk_y(y1, hts, q_mul_w) # MCTD + oys = oys.contiguous().view(B, G, -1, D)[:, :, :L, :].contiguous() + hprefix = hts[:,-1,:,:].contiguous() # MND + + if has_D: + oys = oys + Ds.view(1, G, 1, D) * us + oys = oys.permute(0, 1, 3, 2).contiguous().view(B, -1, L) + hprefix = hprefix.permute(0, 2, 1).contiguous().view(B, GD, N).float() + + 
return oys.to(inp_dtype) if not return_last_state else (oys.to(inp_dtype), hprefix) + + +class SelectiveScanMatrix(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(ctx, us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64): + save_for_backward = [] + if chunksize < 1: + chunksize = Bs.shape[-1] + mask = torch.tril(torch.ones((chunksize, chunksize), device=us.device), diagonal=0) + + def ss_chunk(us, dts, As, Bs, Cs, h0, mask): + # BHLD, BHLN, HND, BHDN + cL = us.shape[2] + _mask = (mask[:cL,:cL].contiguous() if cL < mask.shape[0] else mask).view(1, 1, cL, cL, 1) + + w_log = As[None, :, None, :, :] * (torch.cumsum(dts, dim=2)[..., None, :]) # (B, H, L, Dk, Dv) + v = us * dts # (B,H,L,Dv) + k = Bs # (B,H,L,Dk) + q = Cs # (B,H,L,Dk) + w = w_log.exp() + + k_div_w = k[..., None] / w + q_mul_w = q[..., None] * w + + # h0 independent ==================== + next_h_1 = w[:,:,-1] * torch.einsum("bhlkv,bhlv->bhkv", k_div_w, v) + y_1 = torch.einsum("bhlrv,bhrv->bhlv", torch.einsum("bhlkv,bhrkv->bhlrv", q_mul_w, k_div_w) * _mask, v) + + # h0 dependent ====================== + y_0 = torch.einsum("bhlkv,bhkv->bhlv", q_mul_w, h0) + next_h_0 = w[:,:, -1] * h0 + + next_h = next_h_0 + next_h_1 + y = y_0 + y_1 + + return y, next_h + + dtype = torch.float32 + # dtype = torch.float16 + inp_dtype = us.dtype + has_D = Ds is not None + dts = dts.to(dtype) + + if delta_bias is not None: + dts = dts + delta_bias.view(1, -1, 1).to(dtype) + if delta_softplus: + dts = torch.nn.functional.softplus(dts) + + if len(Bs.shape) == 3: + Bs = Bs.unsqueeze(1) + if len(Cs.shape) == 3: + Cs = Cs.unsqueeze(1) + + B, GD, L = us.shape + B, G, N, L = Bs.shape + D = GD // G + us = us.view(B, G, -1, L).permute(0, 1, 3, 2).to(dtype) + dts = dts.view(B, G, -1, L).permute(0, 1, 3, 2).to(dtype) + As = As.view(G, D, N).permute(0, 2, 1).to(dtype) + Bs = Bs.permute(0, 1, 3, 2).to(dtype) + Cs = Cs.permute(0, 1, 3, 2).to(dtype) + Ds = Ds.view(G, -1).to(dtype) if has_D else None + ctx.shape = (B, G, L, N, D) + + hprefix = us.new_zeros((B, G, N, D), dtype=dtype) + oys = [] + ohs = [hprefix] + for i in range(0, L, chunksize): + ys, hprefix = ss_chunk( + us[:,:, i:i + chunksize], dts[:,:, i:i + chunksize], + As, Bs[:,:, i:i + chunksize], Cs[:,:, i:i + chunksize], hprefix, mask, + ) + oys.append(ys) + ohs.append(hprefix) + + oys = torch.cat(oys, dim=2) + if has_D: + oys = oys + Ds.view(1, G, 1, D) * us + oys = oys.permute(0, 1, 3, 2).contiguous().view(B, -1, L) + hprefix = hprefix.permute(0, 1, 3, 2).contiguous().view(B, GD, N).float() + ohs = torch.stack(ohs, dim=2) # (B,H,LC,K,V) + + ctx.chunksize = chunksize + ctx.delta_softplus = delta_softplus + save_for_backward.extend([mask, us, dts, As, Bs, Cs, Ds, delta_bias, ohs]) + ctx.save_for_backward(*save_for_backward) + + return oys.to(inp_dtype) if not return_last_state else (oys.to(inp_dtype), hprefix) + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, doys: torch.Tensor, *args): + mask, us, dts, As, Bs, Cs, Ds, delta_bias, ohs = ctx.saved_tensors + + B, G, L, N, D = ctx.shape + chunksize = ctx.chunksize + delta_softplus = ctx.delta_softplus + doys = doys.view(B, G, D, L).permute(0, 1, 3, 2) + def rev_comsum_dim_2(x): + cum_sum = torch.cumsum(x, dim=2) + return (x - cum_sum + cum_sum[:,:,-1:None]) + + dus = None + dDs = None + if Ds is not None: + dDs = torch.einsum("bgld,bgld->gd", doys, us).view(-1) + dus = torch.einsum("bgld,gd->bgld", doys, Ds) + + chunks = 
list(range(0, L, chunksize)) + dAs = us.new_zeros((G, N, D), dtype=torch.float) + dus = us.new_zeros((B, G, L, D), dtype=torch.float) if dus is None else dus + ddts = us.new_zeros((B, G, L, D), dtype=torch.float) + dBs = us.new_zeros((B, G, L, N), dtype=torch.float) + dCs = us.new_zeros((B, G, L, N), dtype=torch.float) + dhprefix = us.new_zeros((B, G, N, D), dtype=torch.float) + + ohs_ptr = -2 + for i in chunks[::-1]: + h0 = ohs[:,:, ohs_ptr] + ohs_ptr = ohs_ptr - 1 + # forward procedure ================ + # BHLD, BHLN, HND, BHDN + cus = us[:,:,i:i + chunksize] + cdts = dts[:,:,i:i + chunksize] + cBs = Bs[:,:,i:i + chunksize] + cCs = Cs[:,:,i:i + chunksize] + cdoys = doys[:,:,i:i + chunksize] + cL = cus.shape[2] + _mask = (mask[:cL,:cL].contiguous() if cL < chunksize else mask).view(1, 1, cL, cL, 1) + + ts = torch.cumsum(cdts, dim=2) + w_log = As[None, :, None, :, :] * (ts[..., None, :]) # (B, H, L, Dk, Dv) + v = cus * cdts # (B,H,L,Dv) + k = cBs # (B,H,L,Dk) + q = cCs # (B,H,L,Dk) + w = w_log.exp() + + k_div_w = k[..., None] / w + q_mul_w = q[..., None] * w + + # h0 independent ================ + next_h_1_tmp = torch.einsum("bhlkv,bhlv->bhkv", k_div_w, v) + # next_h_1 = w[:,:,-1] * next_h_1_tmp + y_1_tmp = torch.einsum("bhlkv,bhrkv->bhlrv", q_mul_w, k_div_w) * _mask + # y_1 = torch.einsum("bhlrv,bhrv->bhlv", y_1_tmp, v) + + # h0 dependent ================ + # next_h_0 = w[:,:, -1] * h0 + # y_0 = torch.einsum("bhlkv,bhkv->bhlv", q_mul_w, h0) + # next_h = next_h_0 + next_h_1 + # y = y_0 + y_1 + + # backward procedure ================ + d_k, d_v, d_cus = None, None, None # only h0 independent + d_h0 = None # only h0 dependent + # h0 independent (start from y1, h1) ================ + # involves q, k, v=(cus,cdts), w=(As, cdts) + if True: + d_v_y1 = torch.einsum("bhlv,bhlrv->bhrv", cdoys, y_1_tmp) + d_y1tmp_y1 = torch.einsum("bhlv,bhrv->bhlrv", cdoys, v) * _mask + d_qmulw_y1 = torch.einsum("bhlrv,bhrkv->bhlkv", d_y1tmp_y1, k_div_w) + d_kdivw_y1 = torch.einsum("bhlrv,bhlkv->bhrkv", d_y1tmp_y1, q_mul_w) + + # d_v_nexth1 = torch.einsum("bhkv,bhlkv->bhlv", dhprefix, next_h_1_tmp) + # d_nexth1tmp_nexth1 = torch.einsum("bhkv,bhlv->bhlkv", dhprefix, v) + # d_kdivw_nexth1 = d_nexth1tmp_nexth1 * w[:,:,-1:] + # d_wf1_nexth1 = torch.einsum("bhlkv,bhlkv->bhkv", d_nexth1tmp_nexth1, k_div_w) + + d_wf1_nexth1 = dhprefix * next_h_1_tmp + d_nexth1tmp_nexth1 = dhprefix * w[:, :, -1] + d_kdivw_nexth1 = torch.einsum("bhkv,bhlv->bhlkv", d_nexth1tmp_nexth1, v) + d_v_nexth1 = torch.einsum("bhkv,bhlkv->bhlv", d_nexth1tmp_nexth1, k_div_w) + + + d_q_qmulw_y1 = torch.einsum("bhlkv,bhlkv->bhlk", d_qmulw_y1, w) + d_w_qmulw_y1 = torch.einsum("bhlkv,bhlk->bhlkv", d_qmulw_y1, q) + d_kdivw = d_kdivw_y1 + d_kdivw_nexth1 + d_k = torch.einsum("bhlkv->bhlk", d_kdivw / w) + # c'=(a(b^-1))'=(-a(b^-2))=(-c(b^-1)) + d_w_kdivw = d_kdivw * (-k_div_w / w) + + d_w_h0i = d_w_qmulw_y1 + d_w_kdivw + d_w_h0i[:, :, -1] += d_wf1_nexth1 + d_wlog_h0i = d_w_h0i * w + d_ts_wlog_h0i = torch.einsum("bhlkv,hkv->bhlv", d_wlog_h0i, As) + d_As_h0i = torch.einsum("bhlkv,bhlv->hkv", d_wlog_h0i, ts) + d_cdts_ts_h0i = rev_comsum_dim_2(d_ts_wlog_h0i) + + d_v = d_v_y1 + d_v_nexth1 + d_cdts_v_h0i = d_v * cus + d_cus = d_v * cdts + d_cdts_h0i = d_cdts_ts_h0i + d_cdts_v_h0i + + d_q_h0i = d_q_qmulw_y1 + + # h0 dependent (start from y0, h0) ================ + # involves q,w=(As, cdts),h0 + if True: + d_h0_y0 = torch.einsum("bhlv,bhlkv->bhkv", cdoys, q_mul_w) + d_qmulw_y0 = torch.einsum("bhlv,bhkv->bhlkv", cdoys, h0) + d_h0_nexth0 = dhprefix * w[:,:,-1] + 
d_wf1_nexth0 = dhprefix * h0 + + d_h0 = d_h0_y0 + d_h0_nexth0 + + d_q_h0d = torch.einsum("bhlkv,bhlkv->bhlk", d_qmulw_y0, w) + d_w_h0d = torch.einsum("bhlkv,bhlk->bhlkv", d_qmulw_y0, q) + d_w_h0d[:, :, -1] += d_wf1_nexth0 + d_wlog_h0d = d_w_h0d * w + d_ts_wlog_h0d = torch.einsum("bhlkv,hkv->bhlv", d_wlog_h0d, As) + d_As_h0d = torch.einsum("bhlkv,bhlv->hkv", d_wlog_h0d, ts) + d_cdts_h0d = rev_comsum_dim_2(d_ts_wlog_h0d) + + # store gradient + dus[:, :, i:i + chunksize] += d_cus + ddts[:, :, i:i + chunksize] = (d_cdts_h0i + d_cdts_h0d) + dAs += (d_As_h0i + d_As_h0d) + dBs[:, :, i:i + chunksize] = d_k + dCs[:, :, i:i + chunksize] = (d_q_h0i + d_q_h0d) + dhprefix = d_h0 + + if delta_softplus: + # softplus = log(1 + e^x); dsoftplus = e^x /(1+e^-x) = 1 - 1 / (1+e^x) = 1 - (e^-softplus) + ddts = ddts - ddts * (-dts).exp() + + ddelta_bias = None + if delta_bias is not None: + ddelta_bias = ddts.sum([0, 2]) + ddelta_bias = ddelta_bias.view(-1) + + dAs = dAs.permute(0, 2, 1).contiguous().view(-1, N) + dus = dus.permute(0, 1, 3, 2).contiguous().view(B, -1, L) + ddts = ddts.permute(0, 1, 3, 2).contiguous().view(B, -1, L) + dBs = dBs.permute(0, 1, 3, 2).contiguous() + dCs = dCs.permute(0, 1, 3, 2).contiguous() + + return dus, ddts, dAs, dBs, dCs, dDs, ddelta_bias, None, None, None + + +def selective_scan_easyv2_fwdbwd(u, delta, A, B, C, D, delta_bias=None, delta_softplus=None, + return_last_state=False, chunksize=64): + outs = SelectiveScanMatrix.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, return_last_state, chunksize) + return (outs[0].to(u.dtype), *outs[1:]) if return_last_state else outs[0].to(u.dtype) + + +selective_scan_easy = selective_scan_easy +# selective_scan_easy = selective_scan_easy_fwdbwd +selective_scan_easy = selective_scan_easyv2 +# selective_scan_easy = selective_scan_easyv2_fwdbwd +selective_scan_easy = selective_scan_easyv3 + +from ssmtriton import selective_scan_easyv3 as ss +selective_scan_easy = ss + +# api to fit original mamba_ssm +def build_api_selective_scan(chunksize=64): + def selective_scan_fn(u, delta, A, B, C, D, z=None, + delta_bias=None, delta_softplus=None, + return_last_state=False): + assert z is None + return selective_scan_easy(u, delta, A, B, C, D, delta_bias, delta_softplus, return_last_state, chunksize) + return selective_scan_fn + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) +@pytest.mark.parametrize('wtype', [torch.float32]) +# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [65, 128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [False, True]) +# @pytest.mark.parametrize('has_delta_bias', [True]) +@pytest.mark.parametrize('delta_softplus', [False, True]) +# @pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_z', [False]) +@pytest.mark.parametrize('has_D', [False, True]) +# @pytest.mark.parametrize('has_D', [True]) +# @pytest.mark.parametrize("varBC_groups", [1, 2]) +@pytest.mark.parametrize("varBC_groups", [2]) +@pytest.mark.parametrize("is_variable_C", [True]) +@pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("chunksize", [64]) +# @pytest.mark.parametrize("chunksize", [32]) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, itype, wtype, chunksize): + selective_scan_fn = build_api_selective_scan(chunksize=chunksize) + + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 18 + dstate = 8 + dstate = 1 + is_complex = wtype == torch.complex64 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups 
== 1: + C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + out, *rest = selective_scan_fn( + u, delta, A, B, C, D, z=z, + delta_bias=delta_bias, delta_softplus=delta_softplus, + return_last_state=return_last_state + ) + if return_last_state: + state = rest[0] + out_ref, *rest = selective_scan_ref( + u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, + delta_bias=delta_bias_ref, delta_softplus=delta_softplus, + return_last_state=return_last_state + ) + if return_last_state: + state_ref = rest[0] + # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + # dt_u = delta * u + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + print(f'State max diff: {(state - state_ref).abs().max().item()}') + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') + print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') + print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') + print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') + print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') + if has_D: + print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') + if has_z: + print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') + if has_delta_bias: + print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') + + assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) + assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) + assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) + assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, + atol=atolw if not is_variable_B else atol) + assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, + atol=atolw if not is_variable_C else atol) + if has_D: + assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) + if has_z: + assert 
torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) + if has_delta_bias: + assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + + +# pytest test_selective_scan.py + +if __name__ == "__main__": + test_selective_scan(True, True, 3, True, False, True, True, True, 1011, torch.float32, torch.float32, 64) + # test_selective_scan(True, True, 3, True, False, True, True, True, 5, torch.float32, torch.float32, 64) diff --git a/rscd/models/backbones/lib_mamba/kernels/selective_scan/test_selective_scan_speed.py b/rscd/models/backbones/lib_mamba/kernels/selective_scan/test_selective_scan_speed.py new file mode 100644 index 0000000000000000000000000000000000000000..76345830fcf78cb5a90eb8b2870f61c215070883 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/kernels/selective_scan/test_selective_scan_speed.py @@ -0,0 +1,519 @@ +# Modified by $@#Anonymous#@$ #20240123 +# Copyright (C) 2023, Tri Dao, Albert Gu. + +import math +import torch +import torch.nn.functional as F +import pytest +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd +from einops import rearrange, repeat +import time +from functools import partial + + +def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm", tag=None): + MODE = mode + + class SelectiveScanFn(torch.autograd.Function): + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + if D is not None and (D.dtype != torch.float): + ctx._d_dtype = D.dtype + D = D.float() + if delta_bias is not None and (delta_bias.dtype != torch.float): + ctx._delta_bias_dtype = delta_bias.dtype + delta_bias = delta_bias.float() + + assert u.shape[1] % (B.shape[1] * nrows) == 0 + assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile + + if backnrows > 0: + assert u.shape[1] % (B.shape[1] * backnrows) == 0 + assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile + else: + backnrows = nrows + ctx.backnrows = backnrows + + if MODE in ["mamba_ssm"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + + elif MODE in ["sscore"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) + elif MODE in ["sstest"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) + elif MODE in ["sscorendstate"]: + assert A.shape[-1] == 1 and B.shape[2] == 1 and C.shape[2] == 1 + A = A.view(-1) + B = B.squeeze(2) + C = C.squeeze(2) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1) + else: + raise NotImplementedError + + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + 
if MODE in ["mamba_ssm", "sstest"]: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + elif MODE in ["sscore"]: + return out if not return_last_state else (out, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + if MODE in ["mamba_ssm"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + elif MODE in ["sstest"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False, ctx.backnrows # option to recompute out_z, not used here + ) + elif MODE in ["sscore"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows + ) + elif MODE in ["sscorendstate"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 + ) + dA = dA.unsqueeze(1) + dB = dB.unsqueeze(2) + dC = dC.unsqueeze(2) + else: + raise NotImplementedError + + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + + _dD = None + if D is not None: + if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype): + _dD = dD.to(ctx._d_dtype) + else: + _dD = dD + + _ddelta_bias = None + if delta_bias is not None: + if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype): + _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype) + else: + _ddelta_bias = ddelta_bias + + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, None, None, None) + + def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. 
+ """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows) + + selective_scan_fn.__repr__ = lambda *_ :f"selective_scan_fn | {mode} | {tag}" + + return selective_scan_fn + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +def selective_scan_easy_v2(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=3): + """ + # B: batch_size, G: groups, D: dim, N: state dim, L: seqlen + us: B, G * D, L + dts: B, G * D, L + As: G * D, N + Bs: B, G, N, L + Cs: B, G, N, L + Ds: G * D + delta_bias: G * D + # chunksize can be any as you like. But as the chunksize raises, hs may get None, as exp(sum(delta) A) is really small + """ + def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix, Mask): + """ + partial(h) / partial(t) = Ah + Bu; y = Ch + Du; + => partial(h*exp(-At)) / partial(t) = Bu*exp(-At); + => h_t = h_0 + sum_{0}_{t}_{Bu*exp(A(t-v)) dv}; + => h_b = exp(A(dt_a + ... + dt_{b-1})) * (h_a + sum_{a}_{b-1}_{Bu*exp(-A(dt_a + ... 
+ dt_i)) dt_i}); + y_i = C_i*h_i + D*u_i + """ + """ + us, dts: (L, B, G, D) # L is chunk_size + As: (G, D, N) + Bs, Cs: (L, B, G, N) + Ds: (G, D) + hprefix: (B, G, D, N) + """ + LChunk = us.shape[0] + ts = dts.cumsum(dim=0) + K_chunk_tmp = torch.einsum("gdn,lbgd->lbgdn", -As, ts) + K_chunk = K_chunk_tmp.exp() + Ats = (-K_chunk_tmp).exp() + Q_chunk = torch.einsum("lbgd,lbgn->lbgdn", dts * us, Bs) + if LChunk < chunksize: + Qm_chunk = Mask[:LChunk, :LChunk, None, None, None, None] * Q_chunk.unsqueeze(0).repeat(LChunk, 1, 1, 1, 1, 1) + else: + Qm_chunk = Mask[:, :, None, None, None, None] * Q_chunk.unsqueeze(0).repeat(LChunk, 1, 1, 1, 1, 1) + H_tmp = Ats * torch.einsum("rlbgdn,lbgdn->rbgdn", Qm_chunk, K_chunk) + hs = H_tmp + Ats * hprefix.unsqueeze(0) + ys = torch.einsum("lbgn,lbgdn->lbgd", Cs, hs) + return ys, hs + + dtype = torch.float32 + # dtype = torch.float16 + inp_dtype = us.dtype + has_D = Ds is not None + if chunksize < 1: + chunksize = Bs.shape[-1] + Mask = torch.tril(us.new_ones((chunksize, chunksize))) + + dts = dts.to(dtype) + if delta_bias is not None: + dts = dts + delta_bias.view(1, -1, 1).to(dtype) + if delta_softplus: + dts = torch.nn.functional.softplus(dts) + + if len(Bs.shape) == 3: + Bs = Bs.unsqueeze(1) + if len(Cs.shape) == 3: + Cs = Cs.unsqueeze(1) + B, G, N, L = Bs.shape + us = us.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype) + dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype) + As = As.view(G, -1, N).to(dtype) + Bs = Bs.permute(3, 0, 1, 2).to(dtype) + Cs = Cs.permute(3, 0, 1, 2).to(dtype) + Ds = Ds.view(G, -1).to(dtype) if has_D else None + D = As.shape[1] + + oys = [] + hprefix = us.new_zeros((B, G, D, N), dtype=dtype) + for i in range(0, L, chunksize): + ys, hs = selective_scan_chunk( + us[i:i + chunksize], dts[i:i + chunksize], + As, Bs[i:i + chunksize], Cs[i:i + chunksize], hprefix, Mask + ) + oys.append(ys) + hprefix = hs[-1] + + oys = torch.cat(oys, dim=0) + if has_D: + oys = oys + Ds * us + oys = oys.permute(1, 2, 3, 0).view(B, -1, L) + + # return oys, hprefix.view(B, G * D, N) + return oys.to(inp_dtype) if not return_last_state else (oys.to(inp_dtype), hprefix.view(B, G * D, N).float()) + + +def selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64): + """ + # B: batch_size, G: groups, D: dim, N: state dim, L: seqlen + us: B, G * D, L + dts: B, G * D, L + As: G * D, N + Bs: B, G, N, L + Cs: B, G, N, L + Ds: G * D + delta_bias: G * D + # chunksize can be any as you like. But as the chunksize raises, hs may get None, as exp(sum(delta) A) is really small + """ + def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix): + """ + partial(h) / partial(t) = Ah + Bu; y = Ch + Du; + => partial(h*exp(-At)) / partial(t) = Bu*exp(-At); + => h_t = h_0 + sum_{0}_{t}_{Bu*exp(A(t-v)) dv}; + => h_b = exp(A(dt_a + ... + dt_{b-1})) * (h_a + sum_{a}_{b-1}_{Bu*exp(-A(dt_a + ... 
+ dt_i)) dt_i}); + y_i = C_i*h_i + D*u_i + """ + """ + us, dts: (L, B, G, D) # L is chunk_size + As: (G, D, N) + Bs, Cs: (L, B, G, N) + Ds: (G, D) + hprefix: (B, G, D, N) + """ + ts = dts.cumsum(dim=0) + Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp() + # scale = Ats[-1].detach() + scale = 1 + rAts = Ats / scale + duts = dts * us + dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs) + hs_tmp = rAts * (dtBus / rAts).cumsum(dim=0) + hs = hs_tmp + Ats * hprefix.unsqueeze(0) + ys = torch.einsum("lbgn,lbgdn->lbgd", Cs, hs) + return ys, hs + + + dtype = torch.float32 + # dtype = torch.float16 + inp_dtype = us.dtype + has_D = Ds is not None + if chunksize < 1: + chunksize = Bs.shape[-1] + + dts = dts.to(dtype) + if delta_bias is not None: + dts = dts + delta_bias.view(1, -1, 1).to(dtype) + if delta_softplus: + dts = torch.nn.functional.softplus(dts) + + if len(Bs.shape) == 3: + Bs = Bs.unsqueeze(1) + if len(Cs.shape) == 3: + Cs = Cs.unsqueeze(1) + B, G, N, L = Bs.shape + us = us.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype) + dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype) + As = As.view(G, -1, N).to(dtype) + Bs = Bs.permute(3, 0, 1, 2).to(dtype) + Cs = Cs.permute(3, 0, 1, 2).to(dtype) + Ds = Ds.view(G, -1).to(dtype) if has_D else None + D = As.shape[1] + + oys = [] + hprefix = us.new_zeros((B, G, D, N), dtype=dtype) + for i in range(0, L, chunksize): + ys, hs = selective_scan_chunk( + us[i:i + chunksize], dts[i:i + chunksize], + As, Bs[i:i + chunksize], Cs[i:i + chunksize], hprefix, + ) + oys.append(ys) + hprefix = hs[-1] + + oys = torch.cat(oys, dim=0) + if has_D: + oys = oys + Ds * us + oys = oys.permute(1, 2, 3, 0).view(B, -1, L) + + # return oys, hprefix.view(B, G * D, N) + return oys.to(inp_dtype) if not return_last_state else (oys.to(inp_dtype), hprefix.view(B, G * D, N).float()) + + +from test_selective_scan_easy import selective_scan_easyv3 +selective_scan_easy = selective_scan_easyv3 +from ssmtriton import selective_scan_easyv3 +selective_scan_easy_v2 = selective_scan_easyv3 + +def test_speed(): + MODE = "sscore" + # MODE = "sscorendstate" + wtype = torch.float32 + itype = torch.float32 + itype = torch.float16 + is_variable_B = True + is_variable_C = True + has_D = True + has_z = False # sscore not support z + has_delta_bias = True + varBC_groups = 2 + seqlen = 4096 + # seqlen = 128 + # seqlen = 64 + batch_size = 128 + dim = 24 + dim = 96 + # dim = 384 + # dim = 768 + dstate = 8 + dstate = 1 + # dstate = 24 + delta_softplus = True + device = 'cuda' + TIMES = 100 + import selective_scan_cuda_core + import selective_scan_cuda + # copied from test_selective_scan ====================== + torch.random.manual_seed(0) + is_complex = wtype == torch.complex64 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, 
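+ # D (the skip term) is created in fp32 here, matching the fp32 handling of D and
+ # delta_bias inside SelectiveScanFn.forward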
device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + # ================================ + starts = [] + ends = [] + tests = [ + partial(build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm", tag="ori"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True), + partial(selective_scan_easy, u, delta, A, B, C, D, delta_bias, delta_softplus, return_last_state=True), + partial(selective_scan_easy_v2, u, delta, A, B, C, D, delta_bias, delta_softplus, return_last_state=True), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode=MODE, tag="f1b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode=MODE, tag="f2b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode=MODE, tag="f3b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode=MODE, tag="f4b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode=MODE, tag="f1b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=2), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode=MODE, tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=2), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode=MODE, tag="f2b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode=MODE, tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), + partial(build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm", tag="ori"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True), + ] + + for test in tests: + s = time.time() + for _ in range(TIMES): + with torch.no_grad(): + test() + torch.cuda.synchronize() + torch.cuda.empty_cache() + e = time.time() + starts.append(s) + ends.append(e) + print("fwd", test.func.__repr__(), e - s, flush=True) + for test in tests: + s = time.time() + for _ in range(TIMES): + outs = test() + outs[0].sum().backward() + torch.cuda.synchronize() + torch.cuda.empty_cache() + e = 
time.time() + starts.append(s) + ends.append(e) + print("fwdbwd", test.func.__repr__(), e - s, flush=True) + +test_speed() diff --git a/rscd/models/backbones/lib_mamba/vmamba.py b/rscd/models/backbones/lib_mamba/vmamba.py new file mode 100644 index 0000000000000000000000000000000000000000..281c726fa0c6da59e90cc2f049dfff599b60cd13 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/vmamba.py @@ -0,0 +1,1848 @@ +import os +import time +import math +import copy +from functools import partial +from typing import Optional, Callable, Any +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, trunc_normal_ +from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count + +DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" +# train speed is slower after enabling this opts. +# torch.backends.cudnn.enabled = True +# torch.backends.cudnn.benchmark = True +# torch.backends.cudnn.deterministic = True + +try: + from .csm_triton import cross_scan_fn, cross_merge_fn +except: + from csm_triton import cross_scan_fn, cross_merge_fn + +try: + from .csms6s import selective_scan_fn, selective_scan_flop_jit +except: + from csms6s import selective_scan_fn, selective_scan_flop_jit + +# FLOPs counter not prepared fro mamba2 +# try: +# from .mamba2.ssd_minimal import selective_scan_chunk_fn +# except: +# from mamba2.ssd_minimal import selective_scan_chunk_fn + + +# ===================================================== +# we have this class as linear and conv init differ from each other +# this function enable loading from both conv2d or linear +class Linear2d(nn.Linear): + def forward(self, x: torch.Tensor): + # B, C, H, W = x.shape + return F.conv2d(x, self.weight[:, :, None, None], self.bias) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view(self.weight.shape) + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + +class LayerNorm2d(nn.LayerNorm): + def forward(self, x: torch.Tensor): + x = x.permute(0, 2, 3, 1) + x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) + return x + + +class PatchMerging2D(nn.Module): + def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm, channel_first=False): + super().__init__() + self.dim = dim + Linear = Linear2d if channel_first else nn.Linear + self._patch_merging_pad = self._patch_merging_pad_channel_first if channel_first else self._patch_merging_pad_channel_last + self.reduction = Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False) + self.norm = norm_layer(4 * dim) + + @staticmethod + def _patch_merging_pad_channel_last(x: torch.Tensor): + H, W, _ = x.shape[-3:] + if (W % 2 != 0) or (H % 2 != 0): + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C + x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C + x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C + x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C + return x + + @staticmethod + def _patch_merging_pad_channel_first(x: torch.Tensor): + H, W = x.shape[-2:] + if (W % 2 != 0) or (H % 2 != 0): + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + x0 = x[..., 0::2, 0::2] # ... 
H/2 W/2 + x1 = x[..., 1::2, 0::2] # ... H/2 W/2 + x2 = x[..., 0::2, 1::2] # ... H/2 W/2 + x3 = x[..., 1::2, 1::2] # ... H/2 W/2 + x = torch.cat([x0, x1, x2, x3], 1) # ... H/2 W/2 4*C + return x + + def forward(self, x): + x = self._patch_merging_pad(x) + x = self.norm(x) + x = self.reduction(x) + + return x + + +class Permute(nn.Module): + def __init__(self, *args): + super().__init__() + self.args = args + + def forward(self, x: torch.Tensor): + return x.permute(*self.args) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + Linear = Linear2d if channels_first else nn.Linear + self.fc1 = Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class gMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): + super().__init__() + self.channel_first = channels_first + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + Linear = Linear2d if channels_first else nn.Linear + self.fc1 = Linear(in_features, 2 * hidden_features) + self.act = act_layer() + self.fc2 = Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor): + x = self.fc1(x) + x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) + x = self.fc2(x * self.act(z)) + x = self.drop(x) + return x + + +class SoftmaxSpatial(nn.Softmax): + def forward(self, x: torch.Tensor): + if self.dim == -1: + B, C, H, W = x.shape + return super().forward(x.view(B, C, -1)).view(B, C, H, W) + elif self.dim == 1: + B, H, W, C = x.shape + return super().forward(x.view(B, -1, C)).view(B, H, W, C) + else: + raise NotImplementedError + + +# ===================================================== +class mamba_init: + @staticmethod + def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4): + dt_proj = nn.Linear(dt_rank, d_inner, bias=True) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + # dt_proj.bias._no_reinit = True + + return dt_proj + + @staticmethod + def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): + # S4D real initialization + A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + if copies > 0: + A_log = 
A_log[None].repeat(copies, 1, 1).contiguous() + if merge: + A_log = A_log.flatten(0, 1) + A_log = nn.Parameter(A_log) + A_log._no_weight_decay = True + return A_log + + @staticmethod + def D_init(d_inner, copies=-1, device=None, merge=True): + # D "skip" parameter + D = torch.ones(d_inner, device=device) + if copies > 0: + D = D[None].repeat(copies, 1).contiguous() + if merge: + D = D.flatten(0, 1) + D = nn.Parameter(D) # Keep in fp32 + D._no_weight_decay = True + return D + + @classmethod + def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4): + # dt proj ============================ + dt_projs = [ + cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor) + for _ in range(k_group) + ] + dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) # (K, inner, rank) + dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) # (K, inner) + del dt_projs + + # A, D ======================================= + A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N) + Ds = cls.D_init(d_inner, copies=k_group, merge=True) # (K * D) + return A_logs, Ds, dt_projs_weight, dt_projs_bias + + +# support: v0, v0seq +class SS2Dv0: + def __initv0__( + self, + # basic dims =========== + d_model=96, + d_state=16, + ssm_ratio=2.0, + dt_rank="auto", + # ====================== + dropout=0.0, + # ====================== + seq=False, + force_fp32=True, + **kwargs, + ): + if "channel_first" in kwargs: + assert not kwargs["channel_first"] + act_layer = nn.SiLU + dt_min = 0.001 + dt_max = 0.1 + dt_init = "random" + dt_scale = 1.0 + dt_init_floor = 1e-4 + bias = False + conv_bias = True + d_conv = 3 + k_group = 4 + factory_kwargs = {"device": None, "dtype": None} + super().__init__() + d_inner = int(ssm_ratio * d_model) + dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank + + self.forward = self.forwardv0 + if seq: + self.forward = partial(self.forwardv0, seq=True) + if not force_fp32: + self.forward = partial(self.forwardv0, force_fp32=False) + + # in proj ============================ + self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias) + self.act: nn.Module = act_layer() + self.conv2d = nn.Conv2d( + in_channels=d_inner, + out_channels=d_inner, + groups=d_inner, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + **factory_kwargs, + ) + + # x proj ============================ + self.x_proj = [ + nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False) + for _ in range(k_group) + ] + self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) + del self.x_proj + + # dt proj, A, D ============================ + self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D( + d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4, + ) + + # out proj ======================================= + self.out_norm = nn.LayerNorm(d_inner) + self.out_proj = nn.Linear(d_inner, d_model, bias=bias) + self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() + + def forwardv0(self, x: torch.Tensor, seq=False, force_fp32=True, **kwargs): + x = self.in_proj(x) + x, z = x.chunk(2, dim=-1) # (b, h, w, d) + z = self.act(z) + x = x.permute(0, 3, 1, 2).contiguous() + x = self.conv2d(x) # (b, d, h, w) + x = self.act(x) + selective_scan = partial(selective_scan_fn, backend="mamba") + + B, D, H, W = x.shape + D, N = self.A_logs.shape + K, D, R = self.dt_projs_weight.shape + L = H * W + + x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) + xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) + + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) + if hasattr(self, "x_proj_bias"): + x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) + + xs = xs.view(B, -1, L) # (b, k * d, l) + dts = dts.contiguous().view(B, -1, L) # (b, k * d, l) + Bs = Bs.contiguous() # (b, k, d_state, l) + Cs = Cs.contiguous() # (b, k, d_state, l) + + As = -self.A_logs.float().exp() # (k * d, d_state) + Ds = self.Ds.float() # (k * d) + dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) + + # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 + # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 + to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args) + + if force_fp32: + xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs) + + if seq: + out_y = [] + for i in range(4): + yi = selective_scan( + xs.view(B, K, -1, L)[:, i], dts.view(B, K, -1, L)[:, i], + As.view(K, -1, N)[i], Bs[:, i].unsqueeze(1), Cs[:, i].unsqueeze(1), Ds.view(K, -1)[i], + delta_bias=dt_projs_bias.view(K, -1)[i], + delta_softplus=True, + ).view(B, -1, L) + out_y.append(yi) + out_y = torch.stack(out_y, dim=1) + else: + out_y = selective_scan( + xs, dts, + As, Bs, Cs, Ds, + delta_bias=dt_projs_bias, + delta_softplus=True, + ).view(B, K, -1, L) + assert out_y.dtype == torch.float + + inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) + wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y + + y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) + y = self.out_norm(y).view(B, H, W, -1) + + y = y * z + out = self.dropout(self.out_proj(y)) + return out + + +# support: v01-v05; v051d,v052d,v052dc; +# postfix: _onsigmoid,_onsoftmax,_ondwconv3,_onnone;_nozact,_noz;_oact;_no32; +# history support: v2,v3;v31d,v32d,v32dc; +class SS2Dv2: + def __initv2__( + self, + # basic dims =========== + d_model=96, + d_state=16, + ssm_ratio=2.0, + dt_rank="auto", + act_layer=nn.SiLU, + # dwconv =============== + d_conv=3, # < 2 means no conv + conv_bias=True, + # ====================== + dropout=0.0, + bias=False, + # dt init ============== + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + initialize="v0", + # ====================== + forward_type="v05", + channel_first=False, + # ====================== + **kwargs, + ): + factory_kwargs = {"device": None, "dtype": None} + super().__init__() + d_inner = int(ssm_ratio * d_model) + dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank + self.channel_first = channel_first + 
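+ # shape orientation, using the defaults as an example: d_model=96 and ssm_ratio=2.0 give
+ # d_inner=192 and dt_rank=ceil(96/16)=6; in_proj below then maps d_model -> 2 * d_inner
+ # unless the "_noz" postfix is set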
self.with_dconv = d_conv > 1 + Linear = Linear2d if channel_first else nn.Linear + self.forward = self.forwardv2 + + # tags for forward_type ============================== + checkpostfix = self.checkpostfix + self.disable_force32, forward_type = checkpostfix("_no32", forward_type) + self.oact, forward_type = checkpostfix("_oact", forward_type) + self.disable_z, forward_type = checkpostfix("_noz", forward_type) + self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type) + self.out_norm, forward_type = self.get_outnorm(forward_type, d_inner, channel_first) + + # forward_type debug ======================================= + FORWARD_TYPES = dict( + v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba", scan_force_torch=True), + v02=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba"), + v03=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="oflex"), + v04=partial(self.forward_corev2, force_fp32=False), # selective_scan_backend="oflex", scan_mode="cross2d" + v05=partial(self.forward_corev2, force_fp32=False, no_einsum=True), # selective_scan_backend="oflex", scan_mode="cross2d" + # =============================== + v051d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="unidi"), + v052d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="bidi"), + v052dc=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="cascade2d"), + # =============================== + v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="core"), + v3=partial(self.forward_corev2, force_fp32=False, selective_scan_backend="oflex"), + ) + self.forward_core = FORWARD_TYPES.get(forward_type, None) + k_group = 4 + + # in proj ======================================= + d_proj = d_inner if self.disable_z else (d_inner * 2) + self.in_proj = Linear(d_model, d_proj, bias=bias) + self.act: nn.Module = act_layer() + + # conv ======================================= + if self.with_dconv: + self.conv2d = nn.Conv2d( + in_channels=d_inner, + out_channels=d_inner, + groups=d_inner, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + **factory_kwargs, + ) + + # x proj ============================ + self.x_proj = [ + nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False) + for _ in range(k_group) + ] + self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) + del self.x_proj + + # out proj ======================================= + self.out_act = nn.GELU() if self.oact else nn.Identity() + self.out_proj = Linear(d_inner, d_model, bias=bias) + self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() + + if initialize in ["v0"]: + self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D( + d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4, + ) + elif initialize in ["v1"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) + self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_weight = nn.Parameter(0.1 * torch.randn((k_group, d_inner, dt_rank))) # 0.1 is added in 0430 + self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((k_group, d_inner))) # 0.1 is added in 0430 + elif initialize in ["v2"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) + self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank))) + self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner))) + + def forward_corev2( + self, + x: torch.Tensor=None, + # ============================== + force_fp32=False, # True: input fp32 + # ============================== + ssoflex=True, # True: input 16 or 32 output 32 False: output dtype as input + no_einsum=False, # replace einsum with linear or conv1d to raise throughput + # ============================== + selective_scan_backend = None, + # ============================== + scan_mode = "cross2d", + scan_force_torch = False, + # ============================== + **kwargs, + ): + assert scan_mode in ["unidi", "bidi", "cross2d", "cascade2d"] + assert selective_scan_backend in [None, "oflex", "core", "mamba", "torch"] + delta_softplus = True + out_norm = self.out_norm + channel_first = self.channel_first + to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args) + + B, D, H, W = x.shape + D, N = self.A_logs.shape + K, D, R = self.dt_projs_weight.shape + L = H * W + _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=3)[scan_mode] + + def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True): + return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, ssoflex, backend=selective_scan_backend) + + if _scan_mode == 3: + x_proj_bias = getattr(self, "x_proj_bias", None) + def scan_rowcol( + x: torch.Tensor, + proj_weight: torch.Tensor, + proj_bias: torch.Tensor, + dt_weight: torch.Tensor, + dt_bias: torch.Tensor, # (2*c) + _As: torch.Tensor, # As = -torch.exp(A_logs.to(torch.float))[:2,] # (2*c, d_state) + _Ds: torch.Tensor, + width = True, + ): + # x: (B, D, H, W) + # proj_weight: (2 * D, (R+N+N)) + XB, XD, XH, XW = x.shape + if width: + _B, _D, _L = XB * XH, XD, XW + xs = x.permute(0, 2, 1, 3).contiguous() + else: + _B, _D, _L = XB * XW, XD, XH + xs = x.permute(0, 3, 1, 2).contiguous() + xs = torch.stack([xs, xs.flip(dims=[-1])], dim=2) # (B, H, 2, D, W) + if no_einsum: + x_dbl = F.conv1d(xs.view(_B, -1, _L), proj_weight.view(-1, _D, 1), bias=(proj_bias.view(-1) if proj_bias is not None else None), groups=2) + dts, Bs, Cs = torch.split(x_dbl.view(_B, 2, -1, _L), [R, N, N], dim=2) + dts = F.conv1d(dts.contiguous().view(_B, -1, _L), dt_weight.view(2 * _D, -1, 1), groups=2) + else: + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, proj_weight) + if x_proj_bias is not None: + x_dbl = x_dbl + x_proj_bias.view(1, 2, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts, 
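+ # dt_weight lifts dt from the low rank R back to the full channel dim, separately for the two scan directions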
dt_weight) + + xs = xs.view(_B, -1, _L) + dts = dts.contiguous().view(_B, -1, _L) + As = _As.view(-1, N).to(torch.float) + Bs = Bs.contiguous().view(_B, 2, N, _L) + Cs = Cs.contiguous().view(_B, 2, N, _L) + Ds = _Ds.view(-1) + delta_bias = dt_bias.view(-1).to(torch.float) + + if force_fp32: + xs = xs.to(torch.float) + dts = dts.to(xs.dtype) + Bs = Bs.to(xs.dtype) + Cs = Cs.to(xs.dtype) + + ys: torch.Tensor = selective_scan( + xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus + ).view(_B, 2, -1, _L) + return ys + + As = -self.A_logs.to(torch.float).exp().view(4, -1, N) + x = F.layer_norm(x.permute(0, 2, 3, 1), normalized_shape=(int(x.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan + y_row = scan_rowcol( + x, + proj_weight = self.x_proj_weight.view(4, -1, D)[:2].contiguous(), + proj_bias = (x_proj_bias.view(4, -1)[:2].contiguous() if x_proj_bias is not None else None), + dt_weight = self.dt_projs_weight.view(4, D, -1)[:2].contiguous(), + dt_bias = (self.dt_projs_bias.view(4, -1)[:2].contiguous() if self.dt_projs_bias is not None else None), + _As = As[:2].contiguous().view(-1, N), + _Ds = self.Ds.view(4, -1)[:2].contiguous().view(-1), + width=True, + ).view(B, H, 2, -1, W).sum(dim=2).permute(0, 2, 1, 3) # (B,C,H,W) + y_row = F.layer_norm(y_row.permute(0, 2, 3, 1), normalized_shape=(int(y_row.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan + y_col = scan_rowcol( + y_row, + proj_weight = self.x_proj_weight.view(4, -1, D)[2:].contiguous().to(y_row.dtype), + proj_bias = (x_proj_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if x_proj_bias is not None else None), + dt_weight = self.dt_projs_weight.view(4, D, -1)[2:].contiguous().to(y_row.dtype), + dt_bias = (self.dt_projs_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if self.dt_projs_bias is not None else None), + _As = As[2:].contiguous().view(-1, N), + _Ds = self.Ds.view(4, -1)[2:].contiguous().view(-1), + width=False, + ).view(B, W, 2, -1, H).sum(dim=2).permute(0, 2, 3, 1) + y = y_col + else: + x_proj_bias = getattr(self, "x_proj_bias", None) + xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch) + if no_einsum: + x_dbl = F.conv1d(xs.view(B, -1, L), self.x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K) + dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2) + dts = F.conv1d(dts.contiguous().view(B, -1, L), self.dt_projs_weight.view(K * D, -1, 1), groups=K) + else: + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) + if x_proj_bias is not None: + x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) + + xs = xs.view(B, -1, L) + dts = dts.contiguous().view(B, -1, L) + As = -self.A_logs.to(torch.float).exp() # (k * c, d_state) + Ds = self.Ds.to(torch.float) # (K * c) + Bs = Bs.contiguous().view(B, K, N, L) + Cs = Cs.contiguous().view(B, K, N, L) + delta_bias = self.dt_projs_bias.view(-1).to(torch.float) + + if force_fp32: + xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs) + + ys: torch.Tensor = selective_scan( + xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus + ).view(B, K, -1, H, W) + + y: torch.Tensor = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch) + + if getattr(self, "__DEBUG__", False): + setattr(self, "__data__", dict( + A_logs=self.A_logs, 
Bs=Bs, Cs=Cs, Ds=Ds, + us=xs, dts=dts, delta_bias=delta_bias, + ys=ys, y=y, H=H, W=W, + )) + + y = y.view(B, -1, H, W) + if not channel_first: + y = y.view(B, -1, H * W).transpose(dim0=1, dim1=2).contiguous().view(B, H, W, -1) # (B, L, C) + y = out_norm(y) + + return y.to(x.dtype) + + def forwardv2(self, x: torch.Tensor, **kwargs): + x = self.in_proj(x) + if not self.disable_z: + x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d) + if not self.disable_z_act: + z = self.act(z) + if not self.channel_first: + x = x.permute(0, 3, 1, 2).contiguous() + if self.with_dconv: + x = self.conv2d(x) # (b, d, h, w) + x = self.act(x) + y = self.forward_core(x) + y = self.out_act(y) + if not self.disable_z: + y = y * z + out = self.dropout(self.out_proj(y)) + return out + + @staticmethod + def get_outnorm(forward_type="", d_inner=192, channel_first=True): + def checkpostfix(tag, value): + ret = value[-len(tag):] == tag + if ret: + value = value[:-len(tag)] + return ret, value + + LayerNorm = LayerNorm2d if channel_first else nn.LayerNorm + + out_norm_none, forward_type = checkpostfix("_onnone", forward_type) + out_norm_dwconv3, forward_type = checkpostfix("_ondwconv3", forward_type) + out_norm_cnorm, forward_type = checkpostfix("_oncnorm", forward_type) + out_norm_softmax, forward_type = checkpostfix("_onsoftmax", forward_type) + out_norm_sigmoid, forward_type = checkpostfix("_onsigmoid", forward_type) + + out_norm = nn.Identity() + if out_norm_none: + out_norm = nn.Identity() + elif out_norm_cnorm: + out_norm = nn.Sequential( + LayerNorm(d_inner), + (nn.Identity() if channel_first else Permute(0, 3, 1, 2)), + nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + ) + elif out_norm_dwconv3: + out_norm = nn.Sequential( + (nn.Identity() if channel_first else Permute(0, 3, 1, 2)), + nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + ) + elif out_norm_softmax: + out_norm = SoftmaxSpatial(dim=(-1 if channel_first else 1)) + elif out_norm_sigmoid: + out_norm = nn.Sigmoid() + else: + out_norm = LayerNorm(d_inner) + + return out_norm, forward_type + + @staticmethod + def checkpostfix(tag, value): + ret = value[-len(tag):] == tag + if ret: + value = value[:-len(tag)] + return ret, value + + +# support: xv1a,xv2a,xv3a; +# postfix: _cpos;_ocov;_ocov2;_ca,_ca1;_act;_mul;_onsigmoid,_onsoftmax,_ondwconv3,_onnone; +class SS2Dv3: + def __initxv__( + self, + # basic dims =========== + d_model=96, + d_state=16, + ssm_ratio=2.0, + dt_rank="auto", + # dwconv =============== + d_conv=3, # < 2 means no conv + conv_bias=True, + # ====================== + dropout=0.0, + bias=False, + # dt init ============== + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + initialize="v0", + # ====================== + forward_type="v2", + channel_first=False, + # ====================== + **kwargs, + ): + super().__init__() + d_inner = int(ssm_ratio * d_model) + dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank + self.channel_first = channel_first + self.d_state = d_state + self.dt_rank = dt_rank + self.d_inner = d_inner + k_group = 4 + self.with_dconv = d_conv > 1 + Linear = Linear2d if channel_first else nn.Linear + self.forward = self.forwardxv + + # tags for forward_type ============================== + checkpostfix = SS2Dv2.checkpostfix + self.out_norm, forward_type = 
SS2Dv2.get_outnorm(forward_type, d_inner, channel_first) + self.omul, forward_type = checkpostfix("_mul", forward_type) + self.oact, forward_type = checkpostfix("_act", forward_type) + self.f_omul = nn.Identity() if self.omul else None + self.out_act = nn.GELU() if self.oact else nn.Identity() + + mode = forward_type[:4] + assert mode in ["xv1a", "xv2a", "xv3a"] + + self.forward = partial(self.forwardxv, mode=mode) + self.dts_dim = dict(xv1a=self.dt_rank, xv2a=self.d_inner, xv3a=4 * self.dt_rank)[mode] + d_inner_all = d_inner + self.dts_dim + 8 * d_state + self.in_proj = Linear(d_model, d_inner_all, bias=bias) + + # conv ======================================= + self.cpos = False + self.iconv = False + self.oconv = False + self.oconv2 = False + if self.with_dconv: + cact, forward_type = checkpostfix("_ca", forward_type) + cact1, forward_type = checkpostfix("_ca1", forward_type) + self.cact = nn.SiLU() if cact else nn.Identity() + self.cact = nn.GELU() if cact1 else self.cact + + self.oconv2, forward_type = checkpostfix("_ocov2", forward_type) + self.oconv, forward_type = checkpostfix("_ocov", forward_type) + self.cpos, forward_type = checkpostfix("_cpos", forward_type) + self.iconv = (not self.oconv) and (not self.oconv2) + + if self.iconv: + self.conv2d = nn.Conv2d( + in_channels=d_model, + out_channels=d_model, + groups=d_model, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + ) + if self.oconv: + self.oconv2d = nn.Conv2d( + in_channels=d_inner, + out_channels=d_inner, + groups=d_inner, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + ) + if self.oconv2: + self.conv2d = nn.Conv2d( + in_channels=d_inner_all, + out_channels=d_inner_all, + groups=d_inner_all, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + ) + + # out proj ======================================= + self.out_proj = Linear(d_inner, d_model, bias=bias) + self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + if initialize in ["v0"]: + self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D( + d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4, + ) + elif initialize in ["v1"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) + self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank))) + self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner))) + elif initialize in ["v2"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) + self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank))) + self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner))) + + + if forward_type.startswith("xv2"): + del self.dt_projs_weight + self.dt_projs_weight = None + + def forwardxv(self, x: torch.Tensor, **kwargs): + B, (H, W) = x.shape[0], (x.shape[2:4] if self.channel_first else x.shape[1:3]) + L = H * W + force_fp32 = False + delta_softplus = True + out_norm = self.out_norm + to_dtype = True + + to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args) + + def selective_scan(u, delta, A, B, C, D, delta_bias, delta_softplus): + return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, 
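+ # oflex=True appears to play the same role as ssoflex in forward_corev2: fp16/fp32 in, fp32 out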
oflex=True, backend=None) + + if self.iconv: + x = self.cact(self.conv2d(x)) # (b, d, h, w) + elif self.cpos: + x = x + self.conv2d(x) # (b, d, h, w) + + x = self.in_proj(x) + + if self.oconv2: + x = self.conv2d(x) # (b, d, h, w) + + us, dts, Bs, Cs = x.split([self.d_inner, self.dts_dim, 4 * self.d_state, 4 * self.d_state], dim=(1 if self.channel_first else -1)) + + _us = us + # Bs, Cs = Bs.view(B, H, W, 4, -1), Cs.view(B, H, W, 4, -1) + # Bs, Cs = Bs.view(B, 4, -1, H, W), Cs.view(B, 4, -1, H, W) + us = cross_scan_fn(us.contiguous(), in_channel_first=self.channel_first, out_channel_first=True).view(B, -1, L) + Bs = cross_scan_fn(Bs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L) + Cs = cross_scan_fn(Cs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L) + dts = cross_scan_fn(dts.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=(self.dts_dim == 4 * self.dt_rank)).view(B, L, -1) + if self.dts_dim == self.dt_rank: + dts = F.conv1d(dts, self.dt_projs_weight.view(4 * self.d_inner, self.dt_rank, 1), None, groups=4) + elif self.dts_dim == 4 * self.dt_rank: + dts = F.conv1d(dts, self.dt_projs_weight.view(4 * self.d_inner, self.dt_rank, 1), None, groups=4) + + As = -self.A_logs.to(torch.float).exp() # (k * c, d_state) + Ds = self.Ds.to(torch.float) # (K * c) + delta_bias = self.dt_projs_bias.view(-1).to(torch.float) # (K * c) + + if force_fp32: + us, dts, Bs, Cs = to_fp32(us, dts, Bs, Cs) + + ys: torch.Tensor = selective_scan( + us, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus + ).view(B, 4, -1, H, W) + y: torch.Tensor = cross_merge_fn(ys.contiguous(), in_channel_first=self.channel_first, out_channel_first=True) + y = y.view(B, -1, H, W) if self.channel_first else y.view(B, H, W, -1) + y = out_norm(y) + + if getattr(self, "__DEBUG__", False): + setattr(self, "__data__", dict( + A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds, + us=us, dts=dts, delta_bias=delta_bias, + ys=ys, y=y, + )) + + y = (y.to(x.dtype) if to_dtype else y) + + y = self.out_act(y) + + if self.omul: + y = y * _us + + if self.oconv: + y = y + self.cact(self.oconv2d(_us)) + + out = self.dropout(self.out_proj(y)) + return out + + +# mamba2 support ================================ +class SS2Dm0: + def __initm0__( + self, + # basic dims =========== + d_model=96, + d_state=16, # now with mamba2, dstate should be bigger... 
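+ # in this m0 path the inner dim is grouped into dt_rank heads of size d_inner // dt_rank (see the assert and the Ds / A_logs shapes below)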
+ ssm_ratio=2.0, + dt_rank="auto", + act_layer=nn.GELU, + # dwconv =============== + d_conv=3, # < 2 means no conv + conv_bias=True, + # ====================== + dropout=0.0, + bias=False, + # dt init ============== + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + initialize="v2", + # ====================== + forward_type="m0", + # ====================== + with_initial_state=False, + # ====================== + **kwargs, + ): + factory_kwargs = {"device": None, "dtype": None} + super().__init__() + d_inner = int(ssm_ratio * d_model) + dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank + assert d_inner % dt_rank == 0 + self.with_dconv = d_conv > 1 + Linear = nn.Linear + self.forward = self.forwardm0 + + # tags for forward_type ============================== + checkpostfix = SS2Dv2.checkpostfix + self.disable_force32, forward_type = checkpostfix("_no32", forward_type) + self.oact, forward_type = checkpostfix("_oact", forward_type) + self.disable_z, forward_type = checkpostfix("_noz", forward_type) + self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type) + self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, False) + + # forward_type debug ======================================= + FORWARD_TYPES = dict( + m0=partial(self.forward_corem0, force_fp32=False, dstate=d_state), + ) + self.forward_core = FORWARD_TYPES.get(forward_type, None) + k_group = 4 + + # in proj ======================================= + d_proj = d_inner if self.disable_z else (d_inner * 2) + self.in_proj = Linear(d_model, d_proj, bias=bias) + self.act: nn.Module = act_layer() + + # conv ======================================= + if self.with_dconv: + self.conv2d = nn.Sequential( + Permute(0, 3, 1, 2), + nn.Conv2d( + in_channels=d_inner, + out_channels=d_inner, + groups=d_inner, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + **factory_kwargs, + ), + Permute(0, 2, 3, 1), + ) + + # x proj ============================ + self.x_proj = [ + nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False) + for _ in range(k_group) + ] + self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) + del self.x_proj + + # out proj ======================================= + self.out_act = nn.GELU() if self.oact else nn.Identity() + self.out_proj = Linear(d_inner, d_model, bias=bias) + self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() + + if initialize in ["v1"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank)))) + self.A_logs = nn.Parameter(torch.randn((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((k_group, dt_rank))) # 0.1 is added in 0430 + elif initialize in ["v2"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank)))) + self.A_logs = nn.Parameter(torch.zeros((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, dt_rank))) + + # init state ============================ + self.initial_state = None + if with_initial_state: + self.initial_state = nn.Parameter(torch.zeros((1, k_group * dt_rank, int(d_inner // dt_rank), d_state)), requires_grad=False) + + def forward_corem0( + self, + x: torch.Tensor=None, + # ============================== + force_fp32=False, # True: input fp32 + chunk_size = 64, + dstate = 64, + # ============================== + selective_scan_backend = None, + scan_mode = "cross2d", + scan_force_torch = False, + # ============================== + **kwargs, + ): + assert scan_mode in ["unidi", "bidi", "cross2d"] + assert selective_scan_backend in [None, "triton", "torch"] + x_proj_bias = getattr(self, "x_proj_bias", None) + to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args) + + N = dstate + B, H, W, RD = x.shape + K, R = self.A_logs.shape + K, R, D = self.Ds.shape + assert RD == R * D + L = H * W + KR = K * R + _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=3)[scan_mode] + + initial_state = None + if self.initial_state is not None: + assert self.initial_state.shape[-1] == dstate + initial_state = self.initial_state.detach().repeat(B, 1, 1, 1) + xs = cross_scan_fn(x.view(B, H, W, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch) # (B, H, W, 4, D) + x_dbl = torch.einsum("b l k d, k c d -> b l k c", xs, self.x_proj_weight) + if x_proj_bias is not None: + x_dbl = x_dbl + x_proj_bias.view(1, -1, K, 1) + dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=3) + xs = xs.contiguous().view(B, L, KR, D) + dts = dts.contiguous().view(B, L, KR) + Bs = Bs.contiguous().view(B, L, K, N) + Cs = Cs.contiguous().view(B, L, K, N) + if force_fp32: + xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs) + + As = -self.A_logs.to(torch.float).exp().view(KR) + Ds = self.Ds.to(torch.float).view(KR, D) + dt_bias = self.dt_projs_bias.view(KR) + + if force_fp32: + xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs) + + ys, final_state = selective_scan_chunk_fn( + xs, dts, As, Bs, Cs, chunk_size=chunk_size, D=Ds, dt_bias=dt_bias, + initial_states=initial_state, dt_softplus=True, return_final_states=True, + backend=selective_scan_backend, + ) + y: torch.Tensor = cross_merge_fn(ys.view(B, H, W, K, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch) + + if getattr(self, "__DEBUG__", False): + setattr(self, "__data__", dict( + A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=self.Ds, + us=xs, dts=dts, delta_bias=self.dt_projs_bias, + initial_state=self.initial_state, final_satte=final_state, + ys=ys, y=y, H=H, W=W, + )) + if self.initial_state is not None: + self.initial_state = nn.Parameter(final_state.detach().sum(0, keepdim=True), requires_grad=False) + + y = self.out_norm(y.view(B, H, W, -1)) + + return y.to(x.dtype) + + def forwardm0(self, x: 
torch.Tensor, **kwargs): + x = self.in_proj(x) + if not self.disable_z: + x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d) + if not self.disable_z_act: + z = self.act(z) + if self.with_dconv: + x = self.conv2d(x) # (b, d, h, w) + x = self.act(x) + y = self.forward_core(x) + y = self.out_act(y) + if not self.disable_z: + y = y * z + out = self.dropout(self.out_proj(y)) + return out + + +class SS2D(nn.Module, SS2Dv0, SS2Dv2, SS2Dv3, SS2Dm0): + def __init__( + self, + # basic dims =========== + d_model=96, + d_state=16, + ssm_ratio=2.0, + dt_rank="auto", + act_layer=nn.SiLU, + # dwconv =============== + d_conv=3, # < 2 means no conv + conv_bias=True, + # ====================== + dropout=0.0, + bias=False, + # dt init ============== + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + initialize="v0", + # ====================== + forward_type="v5", + channel_first=False, + # ====================== + **kwargs, + ): + super().__init__() + kwargs.update( + d_model=d_model, d_state=d_state, ssm_ratio=ssm_ratio, dt_rank=dt_rank, + act_layer=act_layer, d_conv=d_conv, conv_bias=conv_bias, dropout=dropout, bias=bias, + dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor, + initialize=initialize, forward_type=forward_type, channel_first=channel_first, + ) + if forward_type in ["v0", "v0seq"]: + self.__initv0__(seq=("seq" in forward_type), **kwargs) + elif forward_type.startswith("xv"): + self.__initxv__(**kwargs) + elif forward_type.startswith("m"): + self.__initm0__(**kwargs) + else: + self.__initv2__(**kwargs) + + +# ===================================================== +class VSSBlock(nn.Module): + def __init__( + self, + hidden_dim: int = 0, + drop_path: float = 0, + norm_layer: nn.Module = nn.LayerNorm, + channel_first=False, + # ============================= + ssm_d_state: int = 16, + ssm_ratio=1, + ssm_dt_rank: Any = "auto", + ssm_act_layer=nn.SiLU, + ssm_conv: int = 3, + ssm_conv_bias=True, + ssm_drop_rate: float = 0, + ssm_init="v0", + forward_type="v05_noz", + # ============================= + mlp_ratio=4.0, + mlp_act_layer=nn.GELU, + mlp_drop_rate: float = 0.0, + gmlp=False, + # ============================= + use_checkpoint: bool = False, + post_norm: bool = False, + **kwargs, + ): + super().__init__() + self.ssm_branch = ssm_ratio > 0 + self.mlp_branch = mlp_ratio > 0 + self.use_checkpoint = use_checkpoint + self.post_norm = post_norm + + if self.ssm_branch: + self.norm = norm_layer(hidden_dim) + self.op = SS2D( + d_model=hidden_dim, + d_state=ssm_d_state, + ssm_ratio=ssm_ratio, + dt_rank=ssm_dt_rank, + act_layer=ssm_act_layer, + # ========================== + d_conv=ssm_conv, + conv_bias=ssm_conv_bias, + # ========================== + dropout=ssm_drop_rate, + # bias=False, + # ========================== + # dt_min=0.001, + # dt_max=0.1, + # dt_init="random", + # dt_scale="random", + # dt_init_floor=1e-4, + initialize=ssm_init, + # ========================== + forward_type=forward_type, + channel_first=channel_first, + ) + + self.drop_path = DropPath(drop_path) + + if self.mlp_branch: + _MLP = Mlp if not gmlp else gMlp + self.norm2 = norm_layer(hidden_dim) + mlp_hidden_dim = int(hidden_dim * mlp_ratio) + self.mlp = _MLP(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=channel_first) + + def _forward(self, input: torch.Tensor): + x = input + if self.ssm_branch: + if self.post_norm: + x = x + 
self.drop_path(self.norm(self.op(x))) + else: + x = x + self.drop_path(self.op(self.norm(x))) + if self.mlp_branch: + if self.post_norm: + x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN + else: + x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN + return x + + def forward(self, input: torch.Tensor): + if self.use_checkpoint: + return checkpoint.checkpoint(self._forward, input) + else: + return self._forward(input) + + +class VSSM(nn.Module): + def __init__( + self, + patch_size=4, + in_chans=3, + num_classes=1000, + depths=[2, 2, 9, 2], + dims=[96, 192, 384, 768], + # ========================= + ssm_d_state=16, + ssm_ratio=2.0, + ssm_dt_rank="auto", + ssm_act_layer="silu", + ssm_conv=3, + ssm_conv_bias=True, + ssm_drop_rate=0.0, + ssm_init="v0", + forward_type="v2", + # ========================= + mlp_ratio=4.0, + mlp_act_layer="gelu", + mlp_drop_rate=0.0, + gmlp=False, + # ========================= + drop_path_rate=0.1, + patch_norm=True, + norm_layer="LN", # "BN", "LN2D" + downsample_version: str = "v2", # "v1", "v2", "v3" + patchembed_version: str = "v1", # "v1", "v2" + use_checkpoint=False, + # ========================= + posembed=False, + imgsize=224, + **kwargs, + ): + super().__init__() + self.channel_first = (norm_layer.lower() in ["bn", "ln2d"]) + self.num_classes = num_classes + self.num_layers = len(depths) + if isinstance(dims, int): + dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] + self.num_features = dims[-1] + self.dims = dims + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + _NORMLAYERS = dict( + ln=nn.LayerNorm, + ln2d=LayerNorm2d, + bn=nn.BatchNorm2d, + ) + + _ACTLAYERS = dict( + silu=nn.SiLU, + gelu=nn.GELU, + relu=nn.ReLU, + sigmoid=nn.Sigmoid, + ) + + norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None) + ssm_act_layer: nn.Module = _ACTLAYERS.get(ssm_act_layer.lower(), None) + mlp_act_layer: nn.Module = _ACTLAYERS.get(mlp_act_layer.lower(), None) + + self.pos_embed = self._pos_embed(dims[0], patch_size, imgsize) if posembed else None + + _make_patch_embed = dict( + v1=self._make_patch_embed, + v2=self._make_patch_embed_v2, + ).get(patchembed_version, None) + self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer, channel_first=self.channel_first) + + _make_downsample = dict( + v1=PatchMerging2D, + v2=self._make_downsample, + v3=self._make_downsample_v3, + none=(lambda *_, **_k: None), + ).get(downsample_version, None) + + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + downsample = _make_downsample( + self.dims[i_layer], + self.dims[i_layer + 1], + norm_layer=norm_layer, + channel_first=self.channel_first, + ) if (i_layer < self.num_layers - 1) else nn.Identity() + + self.layers.append(self._make_layer( + dim = self.dims[i_layer], + drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + use_checkpoint=use_checkpoint, + norm_layer=norm_layer, + downsample=downsample, + channel_first=self.channel_first, + # ================= + ssm_d_state=ssm_d_state, + ssm_ratio=ssm_ratio, + ssm_dt_rank=ssm_dt_rank, + ssm_act_layer=ssm_act_layer, + ssm_conv=ssm_conv, + ssm_conv_bias=ssm_conv_bias, + ssm_drop_rate=ssm_drop_rate, + ssm_init=ssm_init, + forward_type=forward_type, + # ================= + mlp_ratio=mlp_ratio, + mlp_act_layer=mlp_act_layer, + mlp_drop_rate=mlp_drop_rate, + gmlp=gmlp, + )) + + self.classifier = nn.Sequential(OrderedDict( + norm=norm_layer(self.num_features), # B,H,W,C + 
permute=(Permute(0, 3, 1, 2) if not self.channel_first else nn.Identity()), + avgpool=nn.AdaptiveAvgPool2d(1), + flatten=nn.Flatten(1), + head=nn.Linear(self.num_features, num_classes), + )) + + self.apply(self._init_weights) + + @staticmethod + def _pos_embed(embed_dims, patch_size, img_size): + patch_height, patch_width = (img_size // patch_size, img_size // patch_size) + pos_embed = nn.Parameter(torch.zeros(1, embed_dims, patch_height, patch_width)) + trunc_normal_(pos_embed, std=0.02) + return pos_embed + + def _init_weights(self, m: nn.Module): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + # used in building optimizer + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed"} + + # used in building optimizer + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {} + + @staticmethod + def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False): + # if channel first, then Norm and Output are both channel_first + return nn.Sequential( + nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + (norm_layer(embed_dim) if patch_norm else nn.Identity()), + ) + + @staticmethod + def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False): + # if channel first, then Norm and Output are both channel_first + stride = patch_size // 2 + kernel_size = stride + 1 + padding = 1 + return nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding), + (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 2, 3, 1)), + (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()), + (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 3, 1, 2)), + nn.GELU(), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + (norm_layer(embed_dim) if patch_norm else nn.Identity()), + ) + + @staticmethod + def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False): + # if channel first, then Norm and Output are both channel_first + return nn.Sequential( + (nn.Identity() if channel_first else Permute(0, 3, 1, 2)), + nn.Conv2d(dim, out_dim, kernel_size=2, stride=2), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + norm_layer(out_dim), + ) + + @staticmethod + def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False): + # if channel first, then Norm and Output are both channel_first + return nn.Sequential( + (nn.Identity() if channel_first else Permute(0, 3, 1, 2)), + nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + norm_layer(out_dim), + ) + + @staticmethod + def _make_layer( + dim=96, + drop_path=[0.1, 0.1], + use_checkpoint=False, + norm_layer=nn.LayerNorm, + downsample=nn.Identity(), + channel_first=False, + # =========================== + ssm_d_state=16, + ssm_ratio=2.0, + ssm_dt_rank="auto", + ssm_act_layer=nn.SiLU, + ssm_conv=3, + ssm_conv_bias=True, + ssm_drop_rate=0.0, + ssm_init="v0", + forward_type="v2", + # 
=========================== + mlp_ratio=4.0, + mlp_act_layer=nn.GELU, + mlp_drop_rate=0.0, + gmlp=False, + **kwargs, + ): + # if channel first, then Norm and Output are both channel_first + depth = len(drop_path) + blocks = [] + for d in range(depth): + blocks.append(VSSBlock( + hidden_dim=dim, + drop_path=drop_path[d], + norm_layer=norm_layer, + channel_first=channel_first, + ssm_d_state=ssm_d_state, + ssm_ratio=ssm_ratio, + ssm_dt_rank=ssm_dt_rank, + ssm_act_layer=ssm_act_layer, + ssm_conv=ssm_conv, + ssm_conv_bias=ssm_conv_bias, + ssm_drop_rate=ssm_drop_rate, + ssm_init=ssm_init, + forward_type=forward_type, + mlp_ratio=mlp_ratio, + mlp_act_layer=mlp_act_layer, + mlp_drop_rate=mlp_drop_rate, + gmlp=gmlp, + use_checkpoint=use_checkpoint, + )) + + return nn.Sequential(OrderedDict( + blocks=nn.Sequential(*blocks,), + downsample=downsample, + )) + + def forward(self, x: torch.Tensor): + x = self.patch_embed(x) + out_features = [] + if self.pos_embed is not None: + pos_embed = self.pos_embed.permute(0, 2, 3, 1) if not self.channel_first else self.pos_embed + x = x + pos_embed + out_features.append(x) + for layer in self.layers: + x = layer(x) + if len(out_features) < 2: + out_features.append(x) + x = self.classifier(x) + return x + + def flops(self, shape=(3, 224, 224), verbose=True): + # shape = self.__input_shape__[1:] + supported_ops={ + "aten::silu": None, # as relu is in _IGNORED_OPS + "aten::neg": None, # as relu is in _IGNORED_OPS + "aten::exp": None, # as relu is in _IGNORED_OPS + "aten::flip": None, # as permute is in _IGNORED_OPS + # "prim::PythonOp.CrossScan": None, + # "prim::PythonOp.CrossMerge": None, + "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=verbose), + } + + model = copy.deepcopy(self) + model.cuda().eval() + + input = torch.randn((1, *shape), device=next(model.parameters()).device) + params = parameter_count(model)[""] + Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops) + + del model, input + return sum(Gflops.values()) * 1e9 + return f"params {params} GFLOPs {sum(Gflops.values())}" + + # used to load ckpt from previous training code + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + + def check_name(src, state_dict: dict = state_dict, strict=False): + if strict: + if prefix + src in list(state_dict.keys()): + return True + else: + key = prefix + src + for k in list(state_dict.keys()): + if k.startswith(key): + return True + return False + + def change_name(src, dst, state_dict: dict = state_dict, strict=False): + if strict: + if prefix + src in list(state_dict.keys()): + state_dict[prefix + dst] = state_dict[prefix + src] + state_dict.pop(prefix + src) + else: + key = prefix + src + for k in list(state_dict.keys()): + if k.startswith(key): + new_k = prefix + dst + k[len(key):] + state_dict[new_k] = state_dict[k] + state_dict.pop(k) + + if check_name("pos_embed", strict=True): + srcEmb: torch.Tensor = state_dict[prefix + "pos_embed"] + state_dict[prefix + "pos_embed"] = F.interpolate(srcEmb.float(), size=self.pos_embed.shape[2:4], align_corners=False, mode="bicubic").to(srcEmb.device) + + change_name("patch_embed.proj", "patch_embed.0") + change_name("patch_embed.norm", "patch_embed.2") + for i in range(100): + for j in range(100): + change_name(f"layers.{i}.blocks.{j}.ln_1", f"layers.{i}.blocks.{j}.norm") + change_name(f"layers.{i}.blocks.{j}.self_attention", f"layers.{i}.blocks.{j}.op") + 
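# Illustrative example with hypothetical keys: change_name() remaps legacy checkpoint entries by
# prefix, so an old key such as "layers.0.blocks.1.ln_1.weight" would be stored back as
# "layers.0.blocks.1.norm.weight", and "layers.0.blocks.1.self_attention.out_proj.weight" would
# become "layers.0.blocks.1.op.out_proj.weight". The range(100) loops are only a generous upper
# bound on stage/block indices; prefixes that match no key in the checkpoint are simply left
# untouched, because the startswith() scan finds nothing to rename.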
change_name("norm", "classifier.norm") + change_name("head", "classifier.head") + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + +# compatible with openmmlab +class Backbone_VSSM(VSSM): + def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer="ln", **kwargs): + kwargs.update(norm_layer=norm_layer) + super().__init__(**kwargs) + self.channel_first = (norm_layer.lower() in ["bn", "ln2d"]) + _NORMLAYERS = dict( + ln=nn.LayerNorm, + ln2d=LayerNorm2d, + bn=nn.BatchNorm2d, + ) + norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None) + + self.out_indices = out_indices + for i in out_indices: + layer = norm_layer(self.dims[i]) + layer_name = f'outnorm{i}' + self.add_module(layer_name, layer) + + del self.classifier + self.load_pretrained(pretrained) + + def load_pretrained(self, ckpt=None, key="model"): + if ckpt is None: + return + + try: + _ckpt = torch.load(open(ckpt, "rb"), map_location=torch.device("cpu")) + print(f"Successfully load ckpt {ckpt}") + incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False) + print(incompatibleKeys) + except Exception as e: + print(f"Failed loading checkpoint form {ckpt}: {e}") + + def forward(self, x): + def layer_forward(l, x): + x = l.blocks(x) + y = l.downsample(x) + return x, y + + x = self.patch_embed(x) + outs = [] + for i, layer in enumerate(self.layers): + o, x = layer_forward(layer, x) # (B, H, W, C) + if i in self.out_indices: + norm_layer = getattr(self, f'outnorm{i}') + out = norm_layer(o) + if not self.channel_first: + out = out.permute(0, 3, 1, 2) + outs.append(out.contiguous()) + + if len(self.out_indices) == 0: + return x + + return outs + + +# ===================================================== +def vanilla_vmamba_tiny(): + return VSSM( + depths=[2, 2, 9, 2], dims=96, drop_path_rate=0.2, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v0", + mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v1", patchembed_version="v1", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vanilla_vmamba_small(): + return VSSM( + depths=[2, 2, 27, 2], dims=96, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v0", + mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v1", patchembed_version="v1", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vanilla_vmamba_base(): + return VSSM( + depths=[2, 2, 27, 2], dims=128, drop_path_rate=0.6, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v0", + mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v1", patchembed_version="v1", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +# ===================================================== +def vmamba_tiny_s2l5(channel_first=True): + return VSSM( + depths=[2, 2, 5, 2], dims=96, drop_path_rate=0.2, + patch_size=4, 
in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_small_s2l15(channel_first=True): + return VSSM( + depths=[2, 2, 15, 2], dims=96, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_base_s2l15(channel_first=True): + return VSSM( + depths=[2, 2, 15, 2], dims=128, drop_path_rate=0.6, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +# ===================================================== +def vmamba_tiny_s1l8(channel_first=True): + return VSSM( + depths=[2, 2, 8, 2], dims=96, drop_path_rate=0.2, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_small_s1l20(channel_first=True): + return VSSM( + depths=[2, 2, 20, 2], dims=96, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_base_s1l20(channel_first=True): + return VSSM( + depths=[2, 2, 20, 2], dims=128, drop_path_rate=0.5, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +# mamba2 support ===================================================== +def vmamba_tiny_m2(): + return VSSM( + depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2, 
+ patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v2", forward_type="m0_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_small_m2(): + return VSSM( + depths=[2, 2, 12, 2], dims=96, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v2", forward_type="m0_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_base_m2(): + return VSSM( + depths=[2, 2, 12, 2], dims=128, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v2", forward_type="m0_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +if __name__ == "__main__": + model = vmamba_tiny_s1l8() + + # model = VSSM( + # depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2, + # patch_size=4, in_chans=3, num_classes=1000, + # ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu", + # ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + # ssm_init="v2", forward_type="m0_noz", + # mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + # patch_norm=True, norm_layer="ln", + # downsample_version="v3", patchembed_version="v2", + # use_checkpoint=False, posembed=False, imgsize=224, + # ) + # print(parameter_count(model)[""]) + # print(model.flops()) # wrong + # model.cuda().train() + model_weights_path = 'vssm1_tiny_0230s_ckpt_epoch_264.pth' + checkpoint = torch.load(model_weights_path, map_location='cpu') + # if 'model' in checkpoint: + # msg = model.load_state_dict(checkpoint['model'], strict=False) + # print(msg) + model.load_state_dict(checkpoint['model'], strict=False) + model.cuda().eval() + x = torch.randn(1, 3, 256, 256).cuda() + y, features = model(x) + print('finish') + + def bench(model): + import time + inp = torch.randn((128, 3, 224, 224)).cuda() + for _ in range(30): + model(inp) + torch.cuda.synchronize() + tim = time.time() + for _ in range(30): + model(inp) + torch.cuda.synchronize() + tim1 = time.time() - tim + + for _ in range(30): + model(inp).sum().backward() + torch.cuda.synchronize() + tim = time.time() + for _ in range(30): + model(inp).sum().backward() + torch.cuda.synchronize() + tim2 = time.time() - tim + + return tim1 / 30, tim2 / 30 + + # print(bench(model_ref)) + # print(bench(model)) + # + # breakpoint() + + diff --git a/rscd/models/backbones/lib_mamba/vmambanew.py b/rscd/models/backbones/lib_mamba/vmambanew.py new file mode 100644 index 0000000000000000000000000000000000000000..281ed6d4861e29757c84e86311b1151135a43525 --- /dev/null +++ b/rscd/models/backbones/lib_mamba/vmambanew.py @@ -0,0 +1,1581 @@ +import os +import time +import math +import copy +from functools import partial +from typing import Optional, 
Callable, Any +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, trunc_normal_ +from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count + +DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" +# train speed is slower after enabling this opts. +# torch.backends.cudnn.enabled = True +# torch.backends.cudnn.benchmark = True +# torch.backends.cudnn.deterministic = True + +try: + from .csm_triton import cross_scan_fn, cross_merge_fn +except: + from csm_triton import cross_scan_fn, cross_merge_fn + +try: + from .csm_tritonk2 import cross_scan_fn_k2, cross_merge_fn_k2 + from .csm_tritonk2 import cross_scan_fn_k2_torch, cross_merge_fn_k2_torch +except: + from csm_tritonk2 import cross_scan_fn_k2, cross_merge_fn_k2 + from csm_tritonk2 import cross_scan_fn_k2_torch, cross_merge_fn_k2_torch + +try: + from .csms6s import selective_scan_fn, selective_scan_flop_jit +except: + from csms6s import selective_scan_fn, selective_scan_flop_jit + +# FLOPs counter not prepared fro mamba2 +# try: +# from .mamba2.ssd_minimal import selective_scan_chunk_fn +# except: +# from mamba2.ssd_minimal import selective_scan_chunk_fn + + +# ===================================================== +# we have this class as linear and conv init differ from each other +# this function enable loading from both conv2d or linear +class Linear2d(nn.Linear): + def forward(self, x: torch.Tensor): + # B, C, H, W = x.shape + return F.conv2d(x, self.weight[:, :, None, None], self.bias) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view(self.weight.shape) + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + + +class LayerNorm2d(nn.LayerNorm): + def forward(self, x: torch.Tensor): + x = x.permute(0, 2, 3, 1) + x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.permute(0, 3, 1, 2) + return x + + +class PatchMerging2D(nn.Module): + def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm, channel_first=False): + super().__init__() + self.dim = dim + Linear = Linear2d if channel_first else nn.Linear + self._patch_merging_pad = self._patch_merging_pad_channel_first if channel_first else self._patch_merging_pad_channel_last + self.reduction = Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False) + self.norm = norm_layer(4 * dim) + + @staticmethod + def _patch_merging_pad_channel_last(x: torch.Tensor): + H, W, _ = x.shape[-3:] + if (W % 2 != 0) or (H % 2 != 0): + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C + x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C + x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C + x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C + return x + + @staticmethod + def _patch_merging_pad_channel_first(x: torch.Tensor): + H, W = x.shape[-2:] + if (W % 2 != 0) or (H % 2 != 0): + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + x0 = x[..., 0::2, 0::2] # ... H/2 W/2 + x1 = x[..., 1::2, 0::2] # ... H/2 W/2 + x2 = x[..., 0::2, 1::2] # ... H/2 W/2 + x3 = x[..., 1::2, 1::2] # ... H/2 W/2 + x = torch.cat([x0, x1, x2, x3], 1) # ... 
H/2 W/2 4*C + return x + + def forward(self, x): + x = self._patch_merging_pad(x) + x = self.norm(x) + x = self.reduction(x) + + return x + + +class Permute(nn.Module): + def __init__(self, *args): + super().__init__() + self.args = args + + def forward(self, x: torch.Tensor): + return x.permute(*self.args) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., + channels_first=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + Linear = Linear2d if channels_first else nn.Linear + self.fc1 = Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class gMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., + channels_first=False): + super().__init__() + self.channel_first = channels_first + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + Linear = Linear2d if channels_first else nn.Linear + self.fc1 = Linear(in_features, 2 * hidden_features) + self.act = act_layer() + self.fc2 = Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor): + x = self.fc1(x) + x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) + x = self.fc2(x * self.act(z)) + x = self.drop(x) + return x + + +class SoftmaxSpatial(nn.Softmax): + def forward(self, x: torch.Tensor): + if self.dim == -1: + B, C, H, W = x.shape + return super().forward(x.view(B, C, -1).contiguous()).view(B, C, H, W).contiguous() + elif self.dim == 1: + B, H, W, C = x.shape + return super().forward(x.view(B, -1, C).contiguous()).view(B, H, W, C).contiguous() + else: + raise NotImplementedError + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', torch.nn.BatchNorm2d(b)) + torch.nn.init.constant_(self.bn.weight, bn_weight_init) + torch.nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, + groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +# ===================================================== +class mamba_init: + @staticmethod + def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4): + dt_proj = nn.Linear(dt_rank, d_inner, bias=True) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = dt_rank ** -0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min 
and dt_max + dt = torch.exp( + torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + # dt_proj.bias._no_reinit = True + + return dt_proj + + @staticmethod + def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): + # S4D real initialization + A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + if copies > 0: + A_log = A_log[None].repeat(copies, 1, 1).contiguous() + if merge: + A_log = A_log.flatten(0, 1) + A_log = nn.Parameter(A_log) + A_log._no_weight_decay = True + return A_log + + @staticmethod + def D_init(d_inner, copies=-1, device=None, merge=True): + # D "skip" parameter + D = torch.ones(d_inner, device=device) + if copies > 0: + D = D[None].repeat(copies, 1).contiguous() + if merge: + D = D.flatten(0, 1) + D = nn.Parameter(D) # Keep in fp32 + D._no_weight_decay = True + return D + + @classmethod + def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4): + # dt proj ============================ + dt_projs = [ + cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor) + for _ in range(k_group) + ] + dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) # (K, inner, rank) + dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) # (K, inner) + del dt_projs + + # A, D ======================================= + A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N) + Ds = cls.D_init(d_inner, copies=k_group, merge=True) # (K * D) + return A_logs, Ds, dt_projs_weight, dt_projs_bias + + +class SS2Dv2: + def __initv2__( + self, + # basic dims =========== + d_model=96, + d_state=16, + ssm_ratio=2.0, + dt_rank="auto", + act_layer=nn.SiLU, + # dwconv =============== + d_conv=3, # < 2 means no conv + conv_bias=True, + # ====================== + dropout=0.0, + bias=False, + # dt init ============== + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + initialize="v0", + # ====================== + forward_type="v05", + channel_first=False, + # ====================== + k_group=4, + **kwargs, + ): + factory_kwargs = {"device": None, "dtype": None} + super().__init__() + d_inner = int(ssm_ratio * d_model) + dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank + self.channel_first = channel_first + self.with_dconv = d_conv > 1 + Linear = Linear2d if channel_first else nn.Linear + self.forward = self.forwardv2 + + # tags for forward_type ============================== + checkpostfix = self.checkpostfix + self.disable_force32, forward_type = checkpostfix("_no32", forward_type) + self.oact, forward_type = checkpostfix("_oact", forward_type) + self.disable_z, forward_type = checkpostfix("_noz", forward_type) + self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type) + self.out_norm, forward_type = self.get_outnorm(forward_type, d_inner, channel_first) + + # forward_type debug ======================================= + FORWARD_TYPES = dict( + v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba", + 
scan_force_torch=True), + v02=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba"), + v03=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="oflex"), + v04=partial(self.forward_corev2, force_fp32=False), # selective_scan_backend="oflex", scan_mode="cross2d" + v05=partial(self.forward_corev2, force_fp32=False, no_einsum=True), + # selective_scan_backend="oflex", scan_mode="cross2d" + # =============================== + v051d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="unidi"), + v052d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="bidi"), + v052dc=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="cascade2d"), + # =============================== + v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="core"), + v3=partial(self.forward_corev2, force_fp32=False, selective_scan_backend="oflex"), + ) + self.forward_core = FORWARD_TYPES.get(forward_type, None) + self.k_group = k_group + + # in proj ======================================= + d_proj = d_inner if self.disable_z else (d_inner * 2) + self.in_proj = Conv2d_BN(d_model, d_proj) + # self.in_proj = Linear(d_model, d_proj, bias=bias) + self.act: nn.Module = act_layer() + + # conv ======================================= + if self.with_dconv: + self.conv2d = nn.Conv2d( + in_channels=d_inner, + out_channels=d_inner, + groups=d_inner, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + **factory_kwargs, + ) + + # x proj ============================ + self.x_proj = [ + nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False) + # torch.nn.Conv2d(d_inner, (dt_rank + d_state * 2), 1, bias=False) + for _ in range(k_group) + ] + self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) + del self.x_proj + + # out proj ======================================= + self.out_act = nn.GELU() if self.oact else nn.Identity() + self.out_proj = Conv2d_BN(d_inner, d_model) + # self.out_proj = Linear(d_inner, d_model, bias=bias) + self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() + + if initialize in ["v0"]: + self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D( + d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=k_group, + ) + elif initialize in ["v1"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) + self.A_logs = nn.Parameter( + torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_weight = nn.Parameter(0.1 * torch.randn((k_group, d_inner, dt_rank))) # 0.1 is added in 0430 + self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((k_group, d_inner))) # 0.1 is added in 0430 + elif initialize in ["v2"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) + self.A_logs = nn.Parameter( + torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank))) + self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner))) + + def forward_corev2( + self, + x: torch.Tensor = None, + # ============================== + force_fp32=False, # True: input fp32 + # ============================== + ssoflex=True, # True: input 16 or 32 output 32 False: output dtype as input + no_einsum=False, # replace einsum with linear or conv1d to raise throughput + # ============================== + selective_scan_backend=None, + # ============================== + scan_mode="cross2d", + scan_force_torch=False, + # ============================== + **kwargs, + ): + assert scan_mode in ["unidi", "bidi", "cross2d", "cascade2d"] + assert selective_scan_backend in [None, "oflex", "core", "mamba", "torch"] + delta_softplus = True + out_norm = self.out_norm + channel_first = self.channel_first + to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args) + + B, D, H, W = x.shape + D, N = self.A_logs.shape + K, D, R = self.dt_projs_weight.shape + L = H * W + _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=3)[scan_mode] + + def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True): + # print(u.device) + # print(selective_scan_backend) + if u.device == torch.device("cpu"): + selective_scan_backend = "torch" + else: + selective_scan_backend = "oflex" + return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, ssoflex, + backend=selective_scan_backend) + + if _scan_mode == 3: + x_proj_bias = getattr(self, "x_proj_bias", None) + + def scan_rowcol( + x: torch.Tensor, + proj_weight: torch.Tensor, + proj_bias: torch.Tensor, + dt_weight: torch.Tensor, + dt_bias: torch.Tensor, # (2*c) + _As: torch.Tensor, # As = -torch.exp(A_logs.to(torch.float))[:2,] # (2*c, d_state) + _Ds: torch.Tensor, + width=True, + ): + # x: (B, D, H, W) + # proj_weight: (2 * D, (R+N+N)) + XB, XD, XH, XW = x.shape + if width: + _B, _D, _L = XB * XH, XD, XW + xs = x.permute(0, 2, 1, 3).contiguous() + else: + _B, _D, _L = XB * XW, XD, XH + xs = x.permute(0, 3, 1, 2).contiguous() + xs = torch.stack([xs, xs.flip(dims=[-1])], dim=2) # (B, H, 2, D, W) + if no_einsum: + x_dbl = F.conv1d(xs.view(_B, -1, _L), proj_weight.view(-1, _D, 1), + bias=(proj_bias.view(-1) if proj_bias is not None else None), groups=2) + dts, Bs, Cs = torch.split(x_dbl.view(_B, 2, -1, _L), [R, N, N], dim=2) + dts = F.conv1d(dts.contiguous().view(_B, -1, _L), dt_weight.view(2 * _D, -1, 1), groups=2) + else: + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, proj_weight) + if 
x_proj_bias is not None: + x_dbl = x_dbl + x_proj_bias.view(1, 2, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_weight) + + xs = xs.view(_B, -1, _L) + dts = dts.contiguous().view(_B, -1, _L) + As = _As.view(-1, N).to(torch.float) + Bs = Bs.contiguous().view(_B, 2, N, _L) + Cs = Cs.contiguous().view(_B, 2, N, _L) + Ds = _Ds.view(-1) + delta_bias = dt_bias.view(-1).to(torch.float) + + if force_fp32: + xs = xs.to(torch.float) + dts = dts.to(xs.dtype) + Bs = Bs.to(xs.dtype) + Cs = Cs.to(xs.dtype) + + ys: torch.Tensor = selective_scan( + xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus + ).view(_B, 2, -1, _L) + return ys + + As = -self.A_logs.to(torch.float).exp().view(self.k_group, -1, N).contiguous() + x = F.layer_norm(x.permute(0, 2, 3, 1), normalized_shape=(int(x.shape[1]),)).permute(0, 3, 1, + 2).contiguous() # added0510 to avoid nan + y_row = scan_rowcol( + x, + proj_weight=self.x_proj_weight.view(self.k_group, -1, D)[:2].contiguous(), + proj_bias=(x_proj_bias.view(self.k_group, -1)[:2].contiguous() if x_proj_bias is not None else None), + dt_weight=self.dt_projs_weight.view(self.k_group, D, -1)[:2].contiguous(), + dt_bias=(self.dt_projs_bias.view(self.k_group, -1)[ + :2].contiguous() if self.dt_projs_bias is not None else None), + _As=As[:2].contiguous().view(-1, N), + _Ds=self.Ds.view(self.k_group, -1)[:2].contiguous().view(-1), + width=True, + ).view(B, H, 2, -1, W).sum(dim=2).permute(0, 2, 1, 3).contiguous() # (B,C,H,W) + y_row = F.layer_norm(y_row.permute(0, 2, 3, 1), normalized_shape=(int(y_row.shape[1]),)).permute(0, 3, 1, + 2).contiguous() # added0510 to avoid nan + y_col = scan_rowcol( + y_row, + proj_weight=self.x_proj_weight.view(self.k_group, -1, D)[2:].contiguous().to(y_row.dtype), + proj_bias=( + x_proj_bias.view(self.k_group, -1)[2:].contiguous().to( + y_row.dtype) if x_proj_bias is not None else None), + dt_weight=self.dt_projs_weight.view(self.k_group, D, -1)[2:].contiguous().to(y_row.dtype), + dt_bias=(self.dt_projs_bias.view(self.k_group, -1)[2:].contiguous().to( + y_row.dtype) if self.dt_projs_bias is not None else None), + _As=As[2:].contiguous().view(-1, N), + _Ds=self.Ds.view(self.k_group, -1)[2:].contiguous().view(-1), + width=False, + ).view(B, W, 2, -1, H).sum(dim=2).permute(0, 2, 3, 1).contiguous() + y = y_col + else: + x_proj_bias = getattr(self, "x_proj_bias", None) + if self.k_group == 4: + xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, + force_torch=scan_force_torch) + else: + xs = cross_scan_fn_k2(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, + force_torch=scan_force_torch) + if no_einsum: + x_dbl = F.conv1d(xs.view(B, -1, L).contiguous(), self.x_proj_weight.view(-1, D, 1).contiguous(), + bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K) + dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L).contiguous(), [R, N, N], dim=2) + dts = F.conv1d(dts.contiguous().view(B, -1, L).contiguous(), self.dt_projs_weight.view(K * D, -1, 1).contiguous(), groups=K) + else: + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) + if x_proj_bias is not None: + x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1).contiguous() + dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) + + xs = xs.view(B, -1, L).contiguous() + dts = dts.contiguous().view(B, -1, L).contiguous() + As = -self.A_logs.to(torch.float).exp() # (k * c, d_state) + 
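# Note on the log-space parameterization (illustrative): A_logs stores log(-A), so
# A = -exp(A_logs) is negative for any learned value, and with delta_softplus=True the
# effective step dt = softplus(delta + delta_bias) is positive; the per-step decay factor
# exp(A * dt) therefore stays strictly inside (0, 1), which is what the "0 < exp(A * dt) < 1"
# comments above rely on for a stable selective-scan recurrence. For example, A_log = 0.0
# gives A = -1.0, and with dt = 0.1 the decay is exp(-0.1) ≈ 0.905:
#   >>> torch.exp(-torch.exp(torch.tensor(0.0)) * 0.1)  # tensor(0.9048)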
Ds = self.Ds.to(torch.float) # (K * c) + Bs = Bs.contiguous().view(B, K, N, L).contiguous() + Cs = Cs.contiguous().view(B, K, N, L).contiguous() + delta_bias = self.dt_projs_bias.view(-1).contiguous().to(torch.float) + + if force_fp32: + xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs) + + ys: torch.Tensor = selective_scan( + xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus + ).view(B, K, -1, H, W).contiguous() + + if self.k_group == 4: + y: torch.Tensor = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True, scans=_scan_mode, + force_torch=scan_force_torch) + else: + y: torch.Tensor = cross_merge_fn_k2(ys, in_channel_first=True, out_channel_first=True, scans=_scan_mode, + force_torch=scan_force_torch) + + if getattr(self, "__DEBUG__", False): + setattr(self, "__data__", dict( + A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds, + us=xs, dts=dts, delta_bias=delta_bias, + ys=ys, y=y, H=H, W=W, + )) + + y = y.view(B, -1, H, W).contiguous() + if not channel_first: + y = y.view(B, -1, H * W).contiguous().transpose(dim0=1, dim1=2).contiguous().view(B, H, W, -1).contiguous() # (B, L, C) + y = out_norm(y) + + return y.to(x.dtype) + + def forwardv2(self, x: torch.Tensor, **kwargs): + x = self.in_proj(x) + x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d) + z = self.act(z) + x = self.conv2d(x) # (b, d, h, w) + x = self.act(x) + y = self.forward_core(x) + y = self.out_act(y) + y = y * z + out = self.dropout(self.out_proj(y)) + return out + + @staticmethod + def get_outnorm(forward_type="", d_inner=192, channel_first=True): + def checkpostfix(tag, value): + ret = value[-len(tag):] == tag + if ret: + value = value[:-len(tag)] + return ret, value + + LayerNorm = LayerNorm2d if channel_first else nn.LayerNorm + + out_norm_none, forward_type = checkpostfix("_onnone", forward_type) + out_norm_dwconv3, forward_type = checkpostfix("_ondwconv3", forward_type) + out_norm_cnorm, forward_type = checkpostfix("_oncnorm", forward_type) + out_norm_softmax, forward_type = checkpostfix("_onsoftmax", forward_type) + out_norm_sigmoid, forward_type = checkpostfix("_onsigmoid", forward_type) + + out_norm = nn.Identity() + if out_norm_none: + out_norm = nn.Identity() + elif out_norm_cnorm: + out_norm = nn.Sequential( + LayerNorm(d_inner), + (nn.Identity() if channel_first else Permute(0, 3, 1, 2)), + nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + ) + elif out_norm_dwconv3: + out_norm = nn.Sequential( + (nn.Identity() if channel_first else Permute(0, 3, 1, 2)), + nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + ) + elif out_norm_softmax: + out_norm = SoftmaxSpatial(dim=(-1 if channel_first else 1)) + elif out_norm_sigmoid: + out_norm = nn.Sigmoid() + else: + out_norm = LayerNorm(d_inner) + + return out_norm, forward_type + + @staticmethod + def checkpostfix(tag, value): + ret = value[-len(tag):] == tag + if ret: + value = value[:-len(tag)] + return ret, value + + +# mamba2 support ================================ +class SS2Dm0: + def __initm0__( + self, + # basic dims =========== + d_model=96, + d_state=16, # now with mamba2, dstate should be bigger... 
+ ssm_ratio=2.0, + dt_rank="auto", + act_layer=nn.GELU, + # dwconv =============== + d_conv=3, # < 2 means no conv + conv_bias=True, + # ====================== + dropout=0.0, + bias=False, + # dt init ============== + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + initialize="v2", + # ====================== + forward_type="m0", + # ====================== + with_initial_state=False, + channel_first=False, + # ====================== + **kwargs, + ): + factory_kwargs = {"device": None, "dtype": None} + super().__init__() + d_inner = int(ssm_ratio * d_model) + dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank + assert d_inner % dt_rank == 0 + self.channel_first = channel_first + Linear = Linear2d if channel_first else nn.Linear + self.with_dconv = d_conv > 1 + self.forward = self.forwardm0 + + # tags for forward_type ============================== + checkpostfix = SS2Dv2.checkpostfix + self.disable_force32, forward_type = checkpostfix("_no32", forward_type) + self.oact, forward_type = checkpostfix("_oact", forward_type) + self.disable_z, forward_type = checkpostfix("_noz", forward_type) + self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type) + self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, False) + + # forward_type debug ======================================= + FORWARD_TYPES = dict( + m0=partial(self.forward_corem0, force_fp32=False, dstate=d_state), + ) + self.forward_core = FORWARD_TYPES.get(forward_type, None) + k_group = 4 + + # in proj ======================================= + d_proj = d_inner if self.disable_z else (d_inner * 2) + # self.in_proj = Linear(d_model, d_proj, bias=bias) + self.in_proj = Conv2d_BN(d_model, d_proj) + self.act: nn.Module = act_layer() + + # conv ======================================= + if self.with_dconv: + self.conv2d = nn.Conv2d( + in_channels=d_inner, + out_channels=d_inner, + groups=d_inner, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + **factory_kwargs, + ) + + # x proj ============================ + self.x_proj = [ + nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False) + for _ in range(k_group) + ] + self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) + del self.x_proj + + # out proj ======================================= + self.out_act = nn.GELU() if self.oact else nn.Identity() + # self.out_proj = Linear(d_inner, d_model, bias=bias) + self.out_proj = Conv2d_BN(d_inner, d_model) + self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() + + if initialize in ["v1"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank)))) + self.A_logs = nn.Parameter(torch.randn((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((k_group, dt_rank))) # 0.1 is added in 0430 + elif initialize in ["v2"]: + # simple init dt_projs, A_logs, Ds + self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank)))) + self.A_logs = nn.Parameter(torch.zeros((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 + self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, dt_rank))) + + # init state ============================ + self.initial_state = None + if with_initial_state: + self.initial_state = nn.Parameter(torch.zeros((1, k_group * dt_rank, int(d_inner // dt_rank), d_state)), + requires_grad=False) + + def forward_corem0( + self, + x: torch.Tensor = None, + # ============================== + force_fp32=False, # True: input fp32 + chunk_size=64, + dstate=64, + # ============================== + selective_scan_backend='torch', + scan_mode="cross2d", + scan_force_torch=False, + # ============================== + **kwargs, + ): + assert scan_mode in ["unidi", "bidi", "cross2d"] + assert selective_scan_backend in [None, "triton", "torch"] + x_proj_bias = getattr(self, "x_proj_bias", None) + to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args) + + N = dstate + B, H, W, RD = x.shape + K, R = self.A_logs.shape + K, R, D = self.Ds.shape + assert RD == R * D + L = H * W + KR = K * R + _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=3)[scan_mode] + + initial_state = None + if self.initial_state is not None: + assert self.initial_state.shape[-1] == dstate + initial_state = self.initial_state.detach().repeat(B, 1, 1, 1) + xs = cross_scan_fn(x.view(B, H, W, RD).contiguous(), in_channel_first=False, out_channel_first=False, + scans=_scan_mode, force_torch=scan_force_torch) # (B, H, W, 4, D) + x_dbl = torch.einsum("b l k d, k c d -> b l k c", xs, self.x_proj_weight) + if x_proj_bias is not None: + x_dbl = x_dbl + x_proj_bias.view(1, -1, K, 1) + dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=3) + xs = xs.contiguous().view(B, L, KR, D).contiguous() + dts = dts.contiguous().view(B, L, KR).contiguous() + Bs = Bs.contiguous().view(B, L, K, N).contiguous() + Cs = Cs.contiguous().view(B, L, K, N).contiguous() + if force_fp32: + xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs) + + As = -self.A_logs.to(torch.float).exp().view(KR).contiguous() + Ds = self.Ds.to(torch.float).view(KR, D).contiguous() + dt_bias = self.dt_projs_bias.view(KR).contiguous() + + if force_fp32: + xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs) + + ys, final_state = selective_scan_chunk_fn( + xs, dts, As, Bs, Cs, chunk_size=chunk_size, D=Ds, dt_bias=dt_bias, + initial_states=initial_state, dt_softplus=True, return_final_states=True, + backend=selective_scan_backend, + ) + y: torch.Tensor = cross_merge_fn(ys.contiguous().view(B, H, W, K, RD).contiguous(), in_channel_first=False, + out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch) + + if getattr(self, "__DEBUG__", False): + setattr(self, "__data__", dict( + A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=self.Ds, + us=xs, dts=dts, delta_bias=self.dt_projs_bias, + initial_state=self.initial_state, final_satte=final_state, + ys=ys, y=y, H=H, W=W, + )) + if self.initial_state is not None: + self.initial_state = nn.Parameter(final_state.detach().sum(0, 
keepdim=True), requires_grad=False) + + y = self.out_norm(y.view(B, H, W, -1).contiguous()) + + return y.to(x.dtype) + + def forwardm0(self, x: torch.Tensor, **kwargs): + x = self.in_proj(x) + if not self.disable_z: + x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d) + if not self.disable_z_act: + z = self.act(z) + if self.with_dconv: + x = self.conv2d(x) # (b, d, h, w) + x = self.act(x) + + y = self.forward_core(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + y = self.out_act(y) + if not self.disable_z: + y = y * z + out = self.dropout(self.out_proj(y)) + return out + + +class SS2D(nn.Module, SS2Dv2, SS2Dm0): + def __init__( + self, + # basic dims =========== + d_model=96, + d_state=16, + ssm_ratio=2.0, + dt_rank="auto", + act_layer=nn.SiLU, + # dwconv =============== + d_conv=3, # < 2 means no conv + conv_bias=True, + # ====================== + dropout=0.0, + bias=False, + # dt init ============== + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + initialize="v0", + # ====================== + forward_type="v5", + channel_first=False, + # ====================== + k_group=4, + **kwargs, + ): + super().__init__() + kwargs.update( + d_model=d_model, d_state=d_state, ssm_ratio=ssm_ratio, dt_rank=dt_rank, + act_layer=act_layer, d_conv=d_conv, conv_bias=conv_bias, dropout=dropout, bias=bias, + dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor, + initialize=initialize, forward_type=forward_type, channel_first=channel_first, k_group=k_group, + ) + if forward_type in ["v0", "v0seq"]: + self.__initv0__(seq=("seq" in forward_type), **kwargs) + elif forward_type.startswith("xv"): + self.__initxv__(**kwargs) + elif forward_type.startswith("m"): + self.__initm0__(**kwargs) + else: + self.__initv2__(**kwargs) + + +# ===================================================== +class VSSBlock(nn.Module): + def __init__( + self, + hidden_dim: int = 0, + drop_path: float = 0, + norm_layer: nn.Module = nn.LayerNorm, + channel_first=False, + # ============================= + ssm_d_state: int = 16, + ssm_ratio=1, + ssm_dt_rank: Any = "auto", + ssm_act_layer=nn.SiLU, + ssm_conv: int = 3, + ssm_conv_bias=True, + ssm_drop_rate: float = 0, + ssm_init="v0", + forward_type="v05_noz", + # ============================= + mlp_ratio=4.0, + mlp_act_layer=nn.GELU, + mlp_drop_rate: float = 0.0, + gmlp=False, + # ============================= + use_checkpoint: bool = False, + post_norm: bool = False, + **kwargs, + ): + super().__init__() + self.ssm_branch = ssm_ratio > 0 + self.mlp_branch = mlp_ratio > 0 + self.use_checkpoint = use_checkpoint + self.post_norm = post_norm + + if self.ssm_branch: + self.norm = norm_layer(hidden_dim) + self.op = SS2D( + d_model=hidden_dim, + d_state=ssm_d_state, + ssm_ratio=ssm_ratio, + dt_rank=ssm_dt_rank, + act_layer=ssm_act_layer, + # ========================== + d_conv=ssm_conv, + conv_bias=ssm_conv_bias, + # ========================== + dropout=ssm_drop_rate, + # bias=False, + # ========================== + # dt_min=0.001, + # dt_max=0.1, + # dt_init="random", + # dt_scale="random", + # dt_init_floor=1e-4, + initialize=ssm_init, + # ========================== + forward_type=forward_type, + channel_first=channel_first, + ) + + self.drop_path = DropPath(drop_path) + + if self.mlp_branch: + _MLP = Mlp if not gmlp else gMlp + self.norm2 = norm_layer(hidden_dim) + mlp_hidden_dim = int(hidden_dim * mlp_ratio) + self.mlp = _MLP(in_features=hidden_dim, hidden_features=mlp_hidden_dim, 
act_layer=mlp_act_layer, + drop=mlp_drop_rate, channels_first=channel_first) + + def _forward(self, input: torch.Tensor): + x = input + if self.ssm_branch: + if self.post_norm: + x = x + self.drop_path(self.norm(self.op(x))) + else: + x = x + self.drop_path(self.op(self.norm(x))) + if self.mlp_branch: + if self.post_norm: + x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN + else: + x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN + return x + + def forward(self, input: torch.Tensor): + if self.use_checkpoint: + return checkpoint.checkpoint(self._forward, input) + else: + return self._forward(input) + + +class VSSM(nn.Module): + def __init__( + self, + patch_size=4, + in_chans=3, + num_classes=1000, + depths=[2, 2, 9, 2], + dims=[96, 192, 384, 768], + # ========================= + ssm_d_state=16, + ssm_ratio=2.0, + ssm_dt_rank="auto", + ssm_act_layer="silu", + ssm_conv=3, + ssm_conv_bias=True, + ssm_drop_rate=0.0, + ssm_init="v0", + forward_type="v2", + # ========================= + mlp_ratio=4.0, + mlp_act_layer="gelu", + mlp_drop_rate=0.0, + gmlp=False, + # ========================= + drop_path_rate=0.1, + patch_norm=True, + norm_layer="LN", # "BN", "LN2D" + downsample_version: str = "v2", # "v1", "v2", "v3" + patchembed_version: str = "v1", # "v1", "v2" + use_checkpoint=False, + # ========================= + posembed=False, + imgsize=224, + **kwargs, + ): + super().__init__() + self.channel_first = (norm_layer.lower() in ["bn", "ln2d"]) + self.num_classes = num_classes + self.num_layers = len(depths) + if isinstance(dims, int): + dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] + self.num_features = dims[-1] + self.dims = dims + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + _NORMLAYERS = dict( + ln=nn.LayerNorm, + ln2d=LayerNorm2d, + bn=nn.BatchNorm2d, + ) + + _ACTLAYERS = dict( + silu=nn.SiLU, + gelu=nn.GELU, + relu=nn.ReLU, + sigmoid=nn.Sigmoid, + ) + + norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None) + ssm_act_layer: nn.Module = _ACTLAYERS.get(ssm_act_layer.lower(), None) + mlp_act_layer: nn.Module = _ACTLAYERS.get(mlp_act_layer.lower(), None) + + self.pos_embed = self._pos_embed(dims[0], patch_size, imgsize) if posembed else None + + _make_patch_embed = dict( + v1=self._make_patch_embed, + v2=self._make_patch_embed_v2, + ).get(patchembed_version, None) + self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer, + channel_first=self.channel_first) + + _make_downsample = dict( + v1=PatchMerging2D, + v2=self._make_downsample, + v3=self._make_downsample_v3, + none=(lambda *_, **_k: None), + ).get(downsample_version, None) + + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + downsample = _make_downsample( + self.dims[i_layer], + self.dims[i_layer + 1], + norm_layer=norm_layer, + channel_first=self.channel_first, + ) if (i_layer < self.num_layers - 1) else nn.Identity() + + self.layers.append(self._make_layer( + dim=self.dims[i_layer], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + use_checkpoint=use_checkpoint, + norm_layer=norm_layer, + downsample=downsample, + channel_first=self.channel_first, + # ================= + ssm_d_state=ssm_d_state, + ssm_ratio=ssm_ratio, + ssm_dt_rank=ssm_dt_rank, + ssm_act_layer=ssm_act_layer, + ssm_conv=ssm_conv, + ssm_conv_bias=ssm_conv_bias, + ssm_drop_rate=ssm_drop_rate, + ssm_init=ssm_init, + forward_type=forward_type, + # ================= + mlp_ratio=mlp_ratio, + 
mlp_act_layer=mlp_act_layer, + mlp_drop_rate=mlp_drop_rate, + gmlp=gmlp, + )) + + self.classifier = nn.Sequential(OrderedDict( + norm=norm_layer(self.num_features), # B,H,W,C + permute=(Permute(0, 3, 1, 2) if not self.channel_first else nn.Identity()), + avgpool=nn.AdaptiveAvgPool2d(1), + flatten=nn.Flatten(1), + head=nn.Linear(self.num_features, num_classes), + )) + + self.apply(self._init_weights) + + @staticmethod + def _pos_embed(embed_dims, patch_size, img_size): + patch_height, patch_width = (img_size // patch_size, img_size // patch_size) + pos_embed = nn.Parameter(torch.zeros(1, embed_dims, patch_height, patch_width)) + trunc_normal_(pos_embed, std=0.02) + return pos_embed + + def _init_weights(self, m: nn.Module): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + # used in building optimizer + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed"} + + # used in building optimizer + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {} + + @staticmethod + def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, + channel_first=False): + # if channel first, then Norm and Output are both channel_first + return nn.Sequential( + nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + (norm_layer(embed_dim) if patch_norm else nn.Identity()), + ) + + @staticmethod + def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, + channel_first=False): + # if channel first, then Norm and Output are both channel_first + stride = patch_size // 2 + kernel_size = stride + 1 + padding = 1 + return nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding), + (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 2, 3, 1)), + (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()), + (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 3, 1, 2)), + nn.GELU(), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + (norm_layer(embed_dim) if patch_norm else nn.Identity()), + ) + + @staticmethod + def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False): + # if channel first, then Norm and Output are both channel_first + return nn.Sequential( + (nn.Identity() if channel_first else Permute(0, 3, 1, 2)), + nn.Conv2d(dim, out_dim, kernel_size=2, stride=2), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + norm_layer(out_dim), + ) + + @staticmethod + def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False): + # if channel first, then Norm and Output are both channel_first + return nn.Sequential( + (nn.Identity() if channel_first else Permute(0, 3, 1, 2)), + nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1), + (nn.Identity() if channel_first else Permute(0, 2, 3, 1)), + norm_layer(out_dim), + ) + + @staticmethod + def _make_layer( + dim=96, + drop_path=[0.1, 0.1], + use_checkpoint=False, + norm_layer=nn.LayerNorm, + downsample=nn.Identity(), + channel_first=False, + # =========================== + ssm_d_state=16, + 
ssm_ratio=2.0, + ssm_dt_rank="auto", + ssm_act_layer=nn.SiLU, + ssm_conv=3, + ssm_conv_bias=True, + ssm_drop_rate=0.0, + ssm_init="v0", + forward_type="v2", + # =========================== + mlp_ratio=4.0, + mlp_act_layer=nn.GELU, + mlp_drop_rate=0.0, + gmlp=False, + **kwargs, + ): + # if channel first, then Norm and Output are both channel_first + depth = len(drop_path) + blocks = [] + for d in range(depth): + blocks.append(VSSBlock( + hidden_dim=dim, + drop_path=drop_path[d], + norm_layer=norm_layer, + channel_first=channel_first, + ssm_d_state=ssm_d_state, + ssm_ratio=ssm_ratio, + ssm_dt_rank=ssm_dt_rank, + ssm_act_layer=ssm_act_layer, + ssm_conv=ssm_conv, + ssm_conv_bias=ssm_conv_bias, + ssm_drop_rate=ssm_drop_rate, + ssm_init=ssm_init, + forward_type=forward_type, + mlp_ratio=mlp_ratio, + mlp_act_layer=mlp_act_layer, + mlp_drop_rate=mlp_drop_rate, + gmlp=gmlp, + use_checkpoint=use_checkpoint, + )) + + return nn.Sequential(OrderedDict( + blocks=nn.Sequential(*blocks, ), + downsample=downsample, + )) + + def forward(self, x: torch.Tensor): + x = self.patch_embed(x) + out_features = [] + if self.pos_embed is not None: + pos_embed = self.pos_embed.permute(0, 2, 3, 1) if not self.channel_first else self.pos_embed + x = x + pos_embed + out_features.append(x) + for layer in self.layers: + x = layer(x) + if len(out_features) < 2: + out_features.append(x) + x = self.classifier(x) + return x + + def flops(self, shape=(3, 224, 224), verbose=True): + # shape = self.__input_shape__[1:] + supported_ops = { + "aten::silu": None, # as relu is in _IGNORED_OPS + "aten::neg": None, # as relu is in _IGNORED_OPS + "aten::exp": None, # as relu is in _IGNORED_OPS + "aten::flip": None, # as permute is in _IGNORED_OPS + # "prim::PythonOp.CrossScan": None, + # "prim::PythonOp.CrossMerge": None, + "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=verbose), + } + + model = copy.deepcopy(self) + model.cuda().eval() + + input = torch.randn((1, *shape), device=next(model.parameters()).device) + params = parameter_count(model)[""] + Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops) + + del model, input + return sum(Gflops.values()) * 1e9 + return f"params {params} GFLOPs {sum(Gflops.values())}" + + # used to load ckpt from previous training code + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + + def check_name(src, state_dict: dict = state_dict, strict=False): + if strict: + if prefix + src in list(state_dict.keys()): + return True + else: + key = prefix + src + for k in list(state_dict.keys()): + if k.startswith(key): + return True + return False + + def change_name(src, dst, state_dict: dict = state_dict, strict=False): + if strict: + if prefix + src in list(state_dict.keys()): + state_dict[prefix + dst] = state_dict[prefix + src] + state_dict.pop(prefix + src) + else: + key = prefix + src + for k in list(state_dict.keys()): + if k.startswith(key): + new_k = prefix + dst + k[len(key):] + state_dict[new_k] = state_dict[k] + state_dict.pop(k) + + if check_name("pos_embed", strict=True): + srcEmb: torch.Tensor = state_dict[prefix + "pos_embed"] + state_dict[prefix + "pos_embed"] = F.interpolate(srcEmb.float(), size=self.pos_embed.shape[2:4], + align_corners=False, mode="bicubic").to(srcEmb.device) + + change_name("patch_embed.proj", "patch_embed.0") + change_name("patch_embed.norm", "patch_embed.2") + for i in range(100): + for j in range(100): + 
change_name(f"layers.{i}.blocks.{j}.ln_1", f"layers.{i}.blocks.{j}.norm") + change_name(f"layers.{i}.blocks.{j}.self_attention", f"layers.{i}.blocks.{j}.op") + change_name("norm", "classifier.norm") + change_name("head", "classifier.head") + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + + +# compatible with openmmlab +class Backbone_VSSM(VSSM): + def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer="ln", **kwargs): + kwargs.update(norm_layer=norm_layer) + super().__init__(**kwargs) + self.channel_first = (norm_layer.lower() in ["bn", "ln2d"]) + _NORMLAYERS = dict( + ln=nn.LayerNorm, + ln2d=LayerNorm2d, + bn=nn.BatchNorm2d, + ) + norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None) + + self.out_indices = out_indices + for i in out_indices: + layer = norm_layer(self.dims[i]) + layer_name = f'outnorm{i}' + self.add_module(layer_name, layer) + + del self.classifier + self.load_pretrained(pretrained) + + def load_pretrained(self, ckpt=None, key="model"): + if ckpt is None: + return + + try: + _ckpt = torch.load(open(ckpt, "rb"), map_location=torch.device("cpu")) + print(f"Successfully load ckpt {ckpt}") + incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False) + print(incompatibleKeys) + except Exception as e: + print(f"Failed loading checkpoint form {ckpt}: {e}") + + def forward(self, x): + def layer_forward(l, x): + x = l.blocks(x) + y = l.downsample(x) + return x, y + + x = self.patch_embed(x) + outs = [] + for i, layer in enumerate(self.layers): + o, x = layer_forward(layer, x) # (B, H, W, C) + if i in self.out_indices: + norm_layer = getattr(self, f'outnorm{i}') + out = norm_layer(o) + if not self.channel_first: + out = out.permute(0, 3, 1, 2) + outs.append(out.contiguous()) + + if len(self.out_indices) == 0: + return x + + return outs + + +# ===================================================== +def vanilla_vmamba_tiny(): + return VSSM( + depths=[2, 2, 9, 2], dims=96, drop_path_rate=0.2, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v0", + mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v1", patchembed_version="v1", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vanilla_vmamba_small(): + return VSSM( + depths=[2, 2, 27, 2], dims=96, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v0", + mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v1", patchembed_version="v1", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vanilla_vmamba_base(): + return VSSM( + depths=[2, 2, 27, 2], dims=128, drop_path_rate=0.6, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v0", + mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v1", patchembed_version="v1", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +# 
===================================================== +def vmamba_tiny_s2l5(channel_first=True): + return VSSM( + depths=[2, 2, 5, 2], dims=96, drop_path_rate=0.2, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_small_s2l15(channel_first=True): + return VSSM( + depths=[2, 2, 15, 2], dims=96, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_base_s2l15(channel_first=True): + return VSSM( + depths=[2, 2, 15, 2], dims=128, drop_path_rate=0.6, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +# ===================================================== +def vmamba_tiny_s1l8(channel_first=True): + return VSSM( + depths=[2, 2, 8, 2], dims=96, drop_path_rate=0.2, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_small_s1l20(channel_first=True): + return VSSM( + depths=[2, 2, 20, 2], dims=96, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_base_s1l20(channel_first=True): + return VSSM( + depths=[2, 2, 20, 2], dims=128, drop_path_rate=0.5, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v0", forward_type="v05_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"), + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, 
imgsize=224, + ) + + +# mamba2 support ===================================================== +def vmamba_tiny_m2(): + return VSSM( + depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v2", forward_type="m0_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_small_m2(): + return VSSM( + depths=[2, 2, 12, 2], dims=96, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v2", forward_type="m0_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +def vmamba_base_m2(): + return VSSM( + depths=[2, 2, 12, 2], dims=128, drop_path_rate=0.3, + patch_size=4, in_chans=3, num_classes=1000, + ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu", + ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + ssm_init="v2", forward_type="m0_noz", + mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + patch_norm=True, norm_layer="ln", + downsample_version="v3", patchembed_version="v2", + use_checkpoint=False, posembed=False, imgsize=224, + ) + + +if __name__ == "__main__": + model = vmamba_tiny_s1l8() + + # model = VSSM( + # depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2, + # patch_size=4, in_chans=3, num_classes=1000, + # ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu", + # ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0, + # ssm_init="v2", forward_type="m0_noz", + # mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False, + # patch_norm=True, norm_layer="ln", + # downsample_version="v3", patchembed_version="v2", + # use_checkpoint=False, posembed=False, imgsize=224, + # ) + # print(parameter_count(model)[""]) + # print(model.flops()) # wrong + # model.cuda().train() + model_weights_path = 'vssm1_tiny_0230s_ckpt_epoch_264.pth' + checkpoint = torch.load(model_weights_path, map_location='cpu') + # if 'model' in checkpoint: + # msg = model.load_state_dict(checkpoint['model'], strict=False) + # print(msg) + model.load_state_dict(checkpoint['model'], strict=False) + model.cuda().eval() + x = torch.randn(1, 3, 256, 256).cuda() + y, features = model(x) + print('finish') + + + def bench(model): + import time + inp = torch.randn((128, 3, 224, 224)).cuda() + for _ in range(30): + model(inp) + torch.cuda.synchronize() + tim = time.time() + for _ in range(30): + model(inp) + torch.cuda.synchronize() + tim1 = time.time() - tim + + for _ in range(30): + model(inp).sum().backward() + torch.cuda.synchronize() + tim = time.time() + for _ in range(30): + model(inp).sum().backward() + torch.cuda.synchronize() + tim2 = time.time() - tim + + return tim1 / 30, tim2 / 30 + + # print(bench(model_ref)) + # print(bench(model)) + # + # breakpoint() + +
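The SS2D wrapper defined above chooses its initializer purely from the forward_type string, and the factory functions in this file pass several different values ("v0", "v05_noz", "m0_noz"). The short sketch below is not part of the patch; it only restates that dispatch so each setting can be mapped to a code path at a glance.

def ss2d_init_branch(forward_type: str) -> str:
    # Mirrors the branch order in SS2D.__init__ above.
    if forward_type in ["v0", "v0seq"]:
        return "__initv0__"
    elif forward_type.startswith("xv"):
        return "__initxv__"
    elif forward_type.startswith("m"):
        return "__initm0__"
    return "__initv2__"

for ft in ["v0", "v05_noz", "v2", "xv1", "m0_noz"]:
    print(f"forward_type={ft!r} -> {ss2d_init_branch(ft)}")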
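VSSM.__init__ above builds one linearly increasing stochastic-depth schedule over all blocks and slices it per stage using sum(depths[:i_layer]) offsets. The snippet below only reproduces that slicing, with the vanilla_vmamba_tiny settings (depths=[2, 2, 9, 2], drop_path_rate=0.2), to make the per-stage rates visible; it is illustrative and not part of the patch.

import torch

depths, drop_path_rate = [2, 2, 9, 2], 0.2  # vanilla_vmamba_tiny settings above
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for i_layer in range(len(depths)):
    stage = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])]
    print(f"stage {i_layer}: {len(stage)} blocks, drop_path {stage[0]:.3f} .. {stage[-1]:.3f}")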
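Finally, a minimal end-to-end sketch of driving the two entry points defined above. The import path is an assumption based on the package layout this patch introduces, and a GPU is assumed (the __main__ block above also runs on CUDA); the shapes in the comments follow from the default configuration. Note that, as written above, VSSM.forward returns a single logits tensor (the out_features list it builds internally is not returned), so it is unpacked as one value here, while Backbone_VSSM.forward returns a list of per-stage feature maps selected by out_indices.

import torch

# Assumed import path; adjust if the module is named differently in your checkout.
from rscd.models.backbones.lib_mamba.vmamba import vmamba_tiny_s1l8, Backbone_VSSM

device = "cuda"  # mirrors the __main__ block above; a CUDA device is assumed

# Classification-style use: a single logits tensor comes back.
model = vmamba_tiny_s1l8().to(device).eval()
x = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    logits = model(x)  # (1, 1000) with the factory defaults
print(tuple(logits.shape))

# Backbone-style use: a list of per-stage feature maps (strides 4/8/16/32),
# already channel-first here because norm_layer="ln2d".
backbone = Backbone_VSSM(
    out_indices=(0, 1, 2, 3),
    pretrained=None,  # or a checkpoint path consumed by load_pretrained
    depths=[2, 2, 8, 2], dims=96,
    ssm_d_state=1, ssm_ratio=1.0, ssm_conv_bias=False,
    forward_type="v05_noz",
    downsample_version="v3", patchembed_version="v2",
    norm_layer="ln2d",
).to(device).eval()
with torch.no_grad():
    feats = backbone(x)
print([tuple(f.shape) for f in feats])  # e.g. (1, 96, 56, 56) ... (1, 768, 7, 7)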