File size: 623 Bytes
c810120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#pragma once

namespace at {

// Lane mask with all 32 bits set; handed to the *_sync warp shuffle
// intrinsics in this header as their participation mask.
const unsigned int kFullMask = 0xFFFFFFFFu;

// Adds up `value` across the lanes of a warp by repeated shuffle-down.
//
// Each pass adds in the value held by the lane `offset` positions above the
// caller, with `offset` doubling from 1 while it stays below warpSize. When
// every lane of the warp takes part, lane 0 returns the sum over the whole
// warp; the other lanes return partial results only.
//
// kFullMask names all lanes, so the whole warp is expected to reach this
// call together.
template <class scalar_t>
__device__ scalar_t warp_reduce(scalar_t value) {
    scalar_t total = value;
#pragma unroll
    for (int offset = 1; offset < warpSize; offset <<= 1) {
        // From CUDA 9 on, use the *_sync shuffle with an explicit lane mask;
        // older toolkits only offer the mask-less form.
#if __CUDACC_VER_MAJOR__ >= 9
        total += __shfl_down_sync(kFullMask, total, offset);
#else
        total += __shfl_down(total, offset);
#endif
    }
    return total;
}

// Returns the copy of `value` held by lane `lane_id` of the calling warp, so
// every participating lane ends up with that lane's value.
//
// kFullMask names all lanes, so the whole warp is expected to reach this
// call together.
template<class scalar_t>
__device__ scalar_t warp_broadcast(scalar_t value, int lane_id) {
    // From CUDA 9 on, use the *_sync shuffle with an explicit lane mask;
    // older toolkits only offer the mask-less form.
#if __CUDACC_VER_MAJOR__ >= 9
    const scalar_t from_lane = __shfl_sync(kFullMask, value, lane_id);
#else
    const scalar_t from_lane = __shfl(value, lane_id);
#endif
    return from_lane;
}

} // namespace at