| |
|
|
| #pragma once |
|
|
| #include "gemm/loader.h" |
| #include "gemm/mma.h" |
| #include "gemm/params.h" |
| #include "gemm/transforms.h" |
| #include "gemm/utils.h" |
|
|
| using namespace metal; |
|
|
| |
| |
| |
|
|
| namespace mlx { |
| namespace steel { |
|
|
/// Empty tag type parameterised on three compile-time booleans (named for the
/// M, N and K alignment of a GEMM loop). It carries no data; it exists only so
/// the three flags can be passed around as a type / function argument.
template <bool M_aligned, bool N_aligned, bool K_aligned> struct LoopAlignment {};
|
|
| template < |
| typename T, |
| typename U, |
| int BM, |
| int BN, |
| int BK, |
| int WM, |
| int WN, |
| bool transpose_a, |
| bool transpose_b, |
| bool MN_aligned, |
| bool K_aligned, |
| typename AccumType = typename AccumHelper<T>::accum_type, |
| typename Epilogue = TransformNone<U, AccumType>> |
| struct GEMMKernel { |
| STEEL_CONST short tgp_padding_a = 16 / sizeof(T); |
| STEEL_CONST short tgp_padding_b = 16 / sizeof(T); |
| STEEL_CONST short tgp_mem_size_a = |
| transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); |
| STEEL_CONST short tgp_mem_size_b = |
| transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); |
| STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; |
|
|
| STEEL_CONST short tgp_size = WM * WN * 32; |
|
|
| using loader_a_t = BlockLoader< |
| T, |
| transpose_a ? BK : BM, |
| transpose_a ? BM : BK, |
| transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, |
| !transpose_a, |
| tgp_size>; |
| using loader_b_t = BlockLoader< |
| T, |
| transpose_b ? BN : BK, |
| transpose_b ? BK : BN, |
| transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, |
| transpose_b, |
| tgp_size>; |
| using mma_t = BlockMMA< |
| T, |
| U, |
| BM, |
| BN, |
| BK, |
| WM, |
| WN, |
| transpose_a, |
| transpose_b, |
| transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, |
| transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, |
| AccumType, |
| Epilogue>; |
|
|
| |
| template <bool M_aligned, bool N_aligned, bool K_aligned_> |
| static METAL_FUNC void gemm_loop( |
| threadgroup T* As [[threadgroup(0)]], |
| threadgroup T* Bs [[threadgroup(1)]], |
| const int gemm_k_iterations, |
| thread loader_a_t& loader_a, |
| thread loader_b_t& loader_b, |
| thread mma_t& mma_op, |
| thread const short& tgp_bm, |
| thread const short& tgp_bn, |
| thread const short& lbk, |
| LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) { |
| |
| (void)l; |
|
|
| short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); |
|
|
| short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); |
|
|
| for (int k = 0; k < gemm_k_iterations; k++) { |
| threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
| if (M_aligned) { |
| loader_a.load_unsafe(); |
| } else { |
| loader_a.load_safe(tile_dims_A); |
| } |
|
|
| if (N_aligned) { |
| loader_b.load_unsafe(); |
| } else { |
| loader_b.load_safe(tile_dims_B); |
| } |
|
|
| threadgroup_barrier(mem_flags::mem_threadgroup); |
|
|
| |
| mma_op.mma(As, Bs); |
|
|
| |
| loader_a.next(); |
| loader_b.next(); |
| } |
|
|
| if (!K_aligned_) { |
| threadgroup_barrier(mem_flags::mem_threadgroup); |
|
|
| short2 tile_dims_A_last = |
| transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); |
| short2 tile_dims_B_last = |
| transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); |
|
|
| loader_a.load_safe(tile_dims_A_last); |
| loader_b.load_safe(tile_dims_B_last); |
|
|
| threadgroup_barrier(mem_flags::mem_threadgroup); |
|
|
| mma_op.mma(As, Bs); |
| } |
| } |
|
|
| |
| static METAL_FUNC void run( |
| const device T* A [[buffer(0)]], |
| const device T* B [[buffer(1)]], |
| device U* D [[buffer(2)]], |
| const constant GEMMParams* params [[buffer(3)]], |
| threadgroup T* As [[threadgroup(0)]], |
| threadgroup T* Bs [[threadgroup(1)]], |
| uint simd_lane_id [[thread_index_in_simdgroup]], |
| uint simd_group_id [[simdgroup_index_in_threadgroup]], |
| uint3 tid [[threadgroup_position_in_grid]], |
| uint3 lid [[thread_position_in_threadgroup]]) { |
| |
| (void)lid; |
|
|
| const int tid_y = ((tid.y) << params->swizzle_log) + |
| ((tid.x) & ((1 << params->swizzle_log) - 1)); |
| const int tid_x = (tid.x) >> params->swizzle_log; |
|
|
| if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { |
| return; |
| } |
|
|
| threadgroup_barrier(mem_flags::mem_none); |
|
|
| |
| const int c_row = tid_y * BM; |
| const int c_col = tid_x * BN; |
| const size_t c_row_long = size_t(c_row); |
| const size_t c_col_long = size_t(c_col); |
|
|
| A += transpose_a ? c_row_long : c_row_long * params->lda; |
| B += transpose_b ? c_col_long * params->ldb : c_col_long; |
| D += c_row_long * params->ldd + c_col_long; |
|
|
| |
| thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); |
| thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); |
|
|
| |
| thread mma_t mma_op(simd_group_id, simd_lane_id); |
|
|
| int gemm_k_iterations = params->gemm_k_iterations_aligned; |
|
|
| |
| |
| if (MN_aligned) { |
| for (int k = 0; k < gemm_k_iterations; k++) { |
| threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
| loader_a.load_unsafe(); |
| loader_b.load_unsafe(); |
|
|
| threadgroup_barrier(mem_flags::mem_threadgroup); |
|
|
| |
| mma_op.mma(As, Bs); |
|
|
| |
| loader_a.next(); |
| loader_b.next(); |
| } |
|
|
| threadgroup_barrier(mem_flags::mem_none); |
|
|
| |
| if (!K_aligned) { |
| int lbk = params->K - params->gemm_k_iterations_aligned * BK; |
| short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); |
| short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); |
|
|
| loader_a.load_safe(tile_dims_A); |
| loader_b.load_safe(tile_dims_B); |
|
|
| threadgroup_barrier(mem_flags::mem_threadgroup); |
|
|
| mma_op.mma(As, Bs); |
| } |
|
|
| |
| mma_op.store_result(D, params->ldd); |
| return; |
|
|
| } |
| |
| |
| else { |
| short tgp_bm = min(BM, params->M - c_row); |
| short tgp_bn = min(BN, params->N - c_col); |
| short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; |
|
|
| if (tgp_bm == BM && tgp_bn == BN) { |
| gemm_loop<true, true, K_aligned>( |
| As, |
| Bs, |
| gemm_k_iterations, |
| loader_a, |
| loader_b, |
| mma_op, |
| tgp_bm, |
| tgp_bn, |
| leftover_bk); |
|
|
| mma_op.store_result(D, params->ldd); |
| return; |
|
|
| } else if (tgp_bn == BN) { |
| gemm_loop<false, true, K_aligned>( |
| As, |
| Bs, |
| gemm_k_iterations, |
| loader_a, |
| loader_b, |
| mma_op, |
| tgp_bm, |
| tgp_bn, |
| leftover_bk); |
|
|
| mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); |
| return; |
|
|
| } else if (tgp_bm == BM) { |
| gemm_loop<true, false, K_aligned>( |
| As, |
| Bs, |
| gemm_k_iterations, |
| loader_a, |
| loader_b, |
| mma_op, |
| tgp_bm, |
| tgp_bn, |
| leftover_bk); |
|
|
| mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); |
| return; |
|
|
| } else { |
| gemm_loop<false, false, K_aligned>( |
| As, |
| Bs, |
| gemm_k_iterations, |
| loader_a, |
| loader_b, |
| mma_op, |
| tgp_bm, |
| tgp_bn, |
| leftover_bk); |
|
|
| mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); |
| return; |
| } |
| } |
| } |
| }; |
|
|
| } |
| } |