#pragma once
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/OpMathType.h>

namespace at { namespace native {

namespace {

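// Sets up the per-operand base pointers for one (tensor, chunk) assignment and
// reports whether every pointer passes the is_aligned check, so callers can
// choose between the vectorized load_store path and the masked fallback. The
// three overloads below differ only in the metadata struct they read from.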
template<int depth, typename T>
__device__ bool init_args(
    T** args,
    TensorListMetadata<depth>& tl,
    int chunk_idx,
    int chunk_size,
    int tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template<int depth, typename T, typename T2>
__device__ bool init_args(
    T** args,
    TensorListScalarListMetadata<T2, depth>& tl,
    int chunk_idx,
    int chunk_size,
    int tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template<int depth, typename T>
__device__ bool init_args(
    T** args,
    FusedOptimizerTensorListMetadata<depth>& tl,
    int chunk_idx,
    int chunk_size,
    int tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

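// Stages up to kILP elements per thread from each of the first `depth` operand
// pointers into the register array r_args, zero-filling slots that fall past
// the end of the chunk or the tensor.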
template<int depth, typename T>
__device__ void load_args(T r_args[][kILP], T** args, int i_start, int chunk_size, int n) {
  #pragma unroll
  for(int ii = 0; ii < kILP; ii++) {
    int i = i_start + threadIdx.x + ii * blockDim.x;
    for (int r_index = 0; r_index < depth; r_index++) {
      r_args[r_index][ii] = 0;
      if(i < n && i < chunk_size) {
        r_args[r_index][ii] = args[r_index][i];
      }
    }
  }
}

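// Writes this thread's kILP register values back to global memory, skipping
// slots that fall past the end of the chunk or the tensor.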
template<typename T>
__device__ void store_args(T* dst, T* src, int i_start, int chunk_size, int n) {
  #pragma unroll
  for(int ii = 0; ii < kILP; ii++) {
    int i = i_start + threadIdx.x + ii * blockDim.x;
    if(i < n && i < chunk_size)
      dst[i] = src[ii];
  }
}

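// Elementwise args[res_arg_index][i] = op(args[0][i], scalar). When the chunk
// length is a multiple of kILP and every pointer is aligned, the vectorized
// load_store path is taken; otherwise the masked load_args/store_args fallback
// handles the tail.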
template<int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void binary_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    int n,
    int chunk_size,
    bool all_aligned,
    Op op) {
  if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
      load_store(r_args[0], args[0], 0, i_start);
      #pragma unroll
      for(int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                          static_cast<opmath_t>(scalar)));
      }
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  }
  else {
    for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
      load_args<1>(r_args, args, i_start, chunk_size, n);
      #pragma unroll
      for(int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                          static_cast<opmath_t>(scalar)));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}

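// Elementwise args[res_arg_index][i] = args[0][i] + scalar * op(args[1][i], args[2][i]),
// i.e. the addcmul / addcdiv pattern. Same aligned fast path / masked fallback
// structure as binary_op_scalar above.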
template<int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    int n,
    int chunk_size,
    bool all_aligned,
    Op op) {
  if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
      load_store(r_args[0], args[0], 0, i_start);
      load_store(r_args[1], args[1], 0, i_start);
      load_store(r_args[2], args[2], 0, i_start);
      #pragma unroll
      for(int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(static_cast<opmath_t>(r_args[0][ii]) +
                                       scalar * op(static_cast<opmath_t>(r_args[1][ii]),
                                                   static_cast<opmath_t>(r_args[2][ii])));
      }
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  }
  else {
    for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
      load_args<3>(r_args, args, i_start, chunk_size, n);
      #pragma unroll
      for(int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(static_cast<opmath_t>(r_args[0][ii]) +
                                       scalar * op(static_cast<opmath_t>(r_args[1][ii]),
                                                   static_cast<opmath_t>(r_args[2][ii])));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}

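// Per-block functor dispatched through multi_tensor_apply: each CUDA block
// looks up its (tensor, chunk) assignment from the metadata and applies
// op(x, scalar) with a single scalar shared by every tensor in the list.
//
// Hedged sketch of a typical call site (the real ones live in the foreach
// binary-op .cu kernels; the exact operand lists and Op depend on the op being
// implemented):
//
//   // tensor_lists = {inputs, outputs}; depth 2, result written to slot 1
//   multi_tensor_apply<2>(
//       tensor_lists,
//       BinaryOpScalarFunctor<scalar_t, /*depth=*/2, /*r_args_depth=*/1, /*res_arg_index=*/1>(),
//       std::plus<opmath_t>(),
//       scalar.to<opmath_t>());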
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

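// Same as BinaryOpScalarFunctor, except each tensor gets its own scalar, read
// from tl.scalar_vals at that tensor's slot.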
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

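// Elementwise op between two tensor lists with an alpha applied to the second
// operand: args[res_arg_index][i] = op(args[0][i], alpha * args[1][i]), e.g.
// the add/sub-with-alpha list pattern.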
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                            alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    }
    else {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                            alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

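// Overwrites the chunk with zeros (used for zeroing a tensor list). The
// aligned path stores through args[0] and the fallback through
// args[res_arg_index]; the two coincide when res_arg_index is 0.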
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
  __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<1>& tl) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        load_store(args[0], r_args[0], i_start, 0);
      }
    }
    else {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

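// Elementwise args[res_arg_index][i] = op(args[0][i]), i.e. the generic unary
// foreach pattern.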
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    }
    else {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

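// Per-block functor for pointwise_op_scalar with a single scalar shared by
// every tensor in the list (the addcmul/addcdiv-with-scalar case).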
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

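// Per-block functor for pointwise_op_scalar where each tensor supplies its own
// scalar via tl.scalar_vals (the ScalarList overloads).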
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

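// Elementwise op over two tensor lists with the result written to a third:
// args[2][i] = op(args[0][i], args[1][i]). Registers only hold the depth - 1
// input operands; the output slot is written directly.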
template<typename T, int depth>
struct PointwiseOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[depth - 1][kILP];

    if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                            static_cast<opmath_t>(r_args[1][ii])));
        }
        load_store(args[2], r_args[0], i_start, 0);
      }
    }
    else {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
        #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                            static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[2], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

} // anonymous namespace

}} // namespace at::native