#include <ATen/ATen.h>

#include <thrust/device_ptr.h>
#include <thrust/transform.h>

#include <vector>

#include "utils/checks.h"
#include "utils/cuda.cuh"
#include "inplace_abn.h"

#include <ATen/cuda/CUDAContext.h>

// Operations for reduce
template<typename T>
struct SumOp {
  __device__ SumOp(const T *t, int c, int s)
      : tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    return tensor[(batch * chn + plane) * sp + n];
  }
  const T *tensor;
  const int chn;
  const int sp;
};

template<typename T>
struct VarOp {
  __device__ VarOp(T m, const T *t, int c, int s)
      : mean(m), tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    T val = tensor[(batch * chn + plane) * sp + n];
    return (val - mean) * (val - mean);
  }
  const T mean;
  const T *tensor;
  const int chn;
  const int sp;
};

template<typename T>
struct GradOp {
  __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
      : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
    // Recover the normalized value y = (z - bias) / weight from the stored output z
    T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
    T _dz = dz[(batch * chn + plane) * sp + n];
    return Pair<T>(_dz, _y * _dz);
  }
  const T weight;
  const T bias;
  const T *z;
  const T *dz;
  const int chn;
  const int sp;
};

/***********
 * mean_var
 ***********/

// One block per channel (plane); threads cooperate over the batch and spatial dimensions.
template<typename T>
__global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
  int plane = blockIdx.x;
  T norm = T(1) / T(num * sp);

  T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
  __syncthreads();
  T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;

  if (threadIdx.x == 0) {
    mean[plane] = _mean;
    var[plane] = _var;
  }
}

std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
  CHECK_CUDA_INPUT(x);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Prepare output tensors
  auto mean = at::empty({chn}, x.options());
  auto var = at::empty({chn}, x.options());

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
    mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        x.data<scalar_t>(),
        mean.data<scalar_t>(),
        var.data<scalar_t>(),
        num, chn, sp);
  }));

  return {mean, var};
}

/**********
 * forward
 **********/

// Normalizes x in place: y = (x - mean) * rsqrt(var + eps) * |weight| + bias.
template<typename T>
__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
                               bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  T _mean = mean[plane];
  T _var = var[plane];
  T _weight = affine ? abs(weight[plane]) + eps : T(1);
  T _bias = affine ? bias[plane] : T(0);
  T mul = rsqrt(_var + eps) * _weight;

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      T _x = x[(batch * chn + plane) * sp + n];
      T _y = (_x - _mean) * mul + _bias;

      x[(batch * chn + plane) * sp + n] = _y;
    }
  }
}

at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
                        bool affine, float eps) {
  CHECK_CUDA_INPUT(x);
  CHECK_CUDA_INPUT(mean);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
    forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        x.data<scalar_t>(),
        mean.data<scalar_t>(),
        var.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return x;
}

/***********
 * edz_eydz
 ***********/

// Per-channel sums of dz and y * dz, needed by the backward pass.
template<typename T>
__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
                                T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  T _weight = affine ? abs(weight[plane]) + eps : 1.f;
  T _bias = affine ? bias[plane] : 0.f;

  Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
  __syncthreads();

  if (threadIdx.x == 0) {
    edz[plane] = res.v1;
    eydz[plane] = res.v2;
  }
}

std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
                                      bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto edz = at::empty({chn}, z.options());
  auto eydz = at::empty({chn}, z.options());

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
    edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        z.data<scalar_t>(),
        dz.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        edz.data<scalar_t>(),
        eydz.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return {edz, eydz};
}

/***********
 * backward
 ***********/

// Computes dx from z, dz and the per-channel sums produced by edz_eydz_cuda.
template<typename T>
__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias,
                                const T *edz, const T *eydz, T *dx, bool affine, float eps,
                                int num, int chn, int sp) {
  int plane = blockIdx.x;

  T _weight = affine ? abs(weight[plane]) + eps : 1.f;
  T _bias = affine ? bias[plane] : 0.f;
  T _var = var[plane];
  T _edz = edz[plane];
  T _eydz = eydz[plane];

  T _mul = _weight * rsqrt(_var + eps);
  T count = T(num * sp);

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      T _dz = dz[(batch * chn + plane) * sp + n];
      T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;

      dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
    }
  }
}

at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
                         at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);
  CHECK_CUDA_INPUT(edz);
  CHECK_CUDA_INPUT(eydz);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto dx = at::zeros_like(z);

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
    backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        z.data<scalar_t>(),
        dz.data<scalar_t>(),
        var.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        edz.data<scalar_t>(),
        eydz.data<scalar_t>(),
        dx.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return dx;
}

/**************
 * activations
 **************/

// Where z < 0, scales dz by slope and maps z back to the pre-activation value z / slope, in place.
template<typename T>
inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
  // Create thrust pointers
  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);

  auto stream = at::cuda::getCurrentCUDAStream();
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_dz, th_dz + count, th_z, th_dz,
                       [slope] __device__ (const T& dz) { return dz * slope; },
                       [] __device__ (const T& z) { return z < 0; });
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_z, th_z + count, th_z,
                       [slope] __device__ (const T& z) { return z / slope; },
                       [] __device__ (const T& z) { return z < 0; });
}

void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();

  AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
    leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
  }));
}

// Where z < 0, scales dz by (z + 1) and maps z back to the pre-activation value log1p(z), in place.
template<typename T>
inline void elu_backward_impl(T *z, T *dz, int64_t count) {
  // Create thrust pointers
  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);

  auto stream = at::cuda::getCurrentCUDAStream();
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_dz, th_dz + count, th_z, th_z, th_dz,
                       [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
                       [] __device__ (const T& z) { return z < 0; });
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_z, th_z + count, th_z,
                       [] __device__ (const T& z) { return log1p(z); },
                       [] __device__ (const T& z) { return z < 0; });
}

void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();

  AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] {
    elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
  }));
}
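
// Illustrative only: a minimal sketch of how these entry points could be exposed to
// Python as a PyTorch C++ extension. The actual bindings for this project live in a
// separate translation unit (the functions are declared in "inplace_abn.h"), so the
// module layout and exported names below are assumptions, not part of this file.
// Kept commented out so this translation unit does not define a second module.
//
// #include <torch/extension.h>
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("mean_var_cuda", &mean_var_cuda, "Per-channel mean/variance (CUDA)");
//   m.def("forward_cuda", &forward_cuda, "In-place BN forward (CUDA)");
//   m.def("edz_eydz_cuda", &edz_eydz_cuda, "Per-channel gradient sums (CUDA)");
//   m.def("backward_cuda", &backward_cuda, "In-place BN backward (CUDA)");
//   m.def("leaky_relu_backward_cuda", &leaky_relu_backward_cuda,
//         "In-place leaky ReLU inversion + gradient (CUDA)");
//   m.def("elu_backward_cuda", &elu_backward_cuda,
//         "In-place ELU inversion + gradient (CUDA)");
// }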