#pragma once

namespace at {

// Participant mask passed to the *_sync warp intrinsics below: all 32 bits
// set, i.e. every lane of the warp is named as a participant. Callers must
// therefore invoke these helpers from code that all lanes of the warp reach.
const unsigned kFullMask = 0xFFFFFFFF;

// Sums `value` across the lanes of the calling warp using shuffle-down.
//
// Each iteration adds the value held by the lane `delta` positions above the
// current one, with `delta` doubling from 1 up to (but excluding) warpSize.
// After the loop, lane 0 holds the sum over all lanes of the warp. Other
// lanes hold partial sums only (a shuffle-down whose source lane falls
// outside the warp returns the caller's own value), so only lane 0's result
// should be treated as the full reduction.
//
// scalar_t must be a type accepted by the shuffle intrinsics and must
// support operator+=.
//
// On toolchains older than CUDA 9 the legacy mask-less __shfl_down is used.
template <typename scalar_t>
__device__ scalar_t warp_reduce(scalar_t value) {
#pragma unroll
  for (int delta = 1; delta < warpSize; delta *= 2) {
#if __CUDACC_VER_MAJOR__ >= 9
    value += __shfl_down_sync(kFullMask, value, delta);
#else
    value += __shfl_down(value, delta);
#endif
  }
  return value;
}

// Returns, in every participating lane, the `value` held by lane `lane_id`
// of the calling warp.
//
// On toolchains older than CUDA 9 the legacy mask-less __shfl is used.
template <typename scalar_t>
__device__ scalar_t warp_broadcast(scalar_t value, int lane_id) {
#if __CUDACC_VER_MAJOR__ >= 9
  return __shfl_sync(kFullMask, value, lane_id);
#else
  return __shfl(value, lane_id);
#endif
}

}  // namespace at