#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <thrust/device_ptr.h>
#include <thrust/transform.h>

#include <vector>

#include "utils/checks.h"
#include "utils/cuda.cuh"
#include "inplace_abn.h"
// Operations for reduce
template<typename T>
struct SumOp {
  __device__ SumOp(const T *t, int c, int s)
      : tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    return tensor[(batch * chn + plane) * sp + n];
  }
  const T *tensor;
  const int chn;
  const int sp;
};

template<typename T>
struct VarOp {
  __device__ VarOp(T m, const T *t, int c, int s)
      : mean(m), tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    T val = tensor[(batch * chn + plane) * sp + n];
    return (val - mean) * (val - mean);
  }
  const T mean;
  const T *tensor;
  const int chn;
  const int sp;
};

template<typename T>
struct GradOp {
  __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
      : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
    T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
    T _dz = dz[(batch * chn + plane) * sp + n];
    return Pair<T>(_dz, _y * _dz);
  }
  const T weight;
  const T bias;
  const T *z;
  const T *dz;
  const int chn;
  const int sp;
};
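// These functors are consumed by the block-wide reduction helpers assumed to be
// provided by "utils/cuda.cuh" (reduce<>, Pair<>, getNumThreads): reduce<> calls
// the functor once per (batch, spatial) position of the channel handled by the
// current block and accumulates the returned values across the block.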
/***********
 * mean_var
 ***********/
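// One thread block per channel: the block reduces over all (batch, spatial)
// positions of its channel to obtain the biased per-channel mean and variance.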
template<typename T>
__global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
  int plane = blockIdx.x;

  T norm = T(1) / T(num * sp);

  T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
  __syncthreads();
  T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;

  if (threadIdx.x == 0) {
    mean[plane] = _mean;
    var[plane] = _var;
  }
}
std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
  CHECK_CUDA_INPUT(x);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Prepare output tensors
  auto mean = at::empty({chn}, x.options());
  auto var = at::empty({chn}, x.options());

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
    mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        x.data<scalar_t>(),
        mean.data<scalar_t>(),
        var.data<scalar_t>(),
        num, chn, sp);
  }));

  return {mean, var};
}
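// Hypothetical host-side usage sketch (assuming `x` is a contiguous NCHW CUDA tensor
// and `weight` / `bias` are per-channel CUDA tensors; names are illustrative only):
//
//   auto stats = mean_var_cuda(x);   // stats[0]: per-channel mean, stats[1]: biased variance
//   auto y = forward_cuda(x, stats[0], stats[1], weight, bias,
//                         /*affine=*/true, /*eps=*/1e-5f);   // normalizes x in place, returns it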
/**********
 * forward
 **********/
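// In-place transform: every element of x is overwritten with
//   y = (x - mean[c]) * rsqrt(var[c] + eps) * (|weight[c]| + eps) + bias[c]
// when affine is true, and with weight = 1, bias = 0 otherwise.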
template<typename T>
__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
                               bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  T _mean = mean[plane];
  T _var = var[plane];
  T _weight = affine ? abs(weight[plane]) + eps : T(1);
  T _bias = affine ? bias[plane] : T(0);

  T mul = rsqrt(_var + eps) * _weight;

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      T _x = x[(batch * chn + plane) * sp + n];
      T _y = (_x - _mean) * mul + _bias;

      x[(batch * chn + plane) * sp + n] = _y;
    }
  }
}
at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
                        bool affine, float eps) {
  CHECK_CUDA_INPUT(x);
  CHECK_CUDA_INPUT(mean);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
    forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        x.data<scalar_t>(),
        mean.data<scalar_t>(),
        var.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return x;
}
/***********
 * edz_eydz
 ***********/
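// Per-channel reductions needed by the backward pass:
//   edz[c]  = sum over the plane of dz
//   eydz[c] = sum over the plane of y * dz, where y = (z - bias[c]) / (|weight[c]| + eps)
// is the normalized activation recovered from the in-place output z.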
template<typename T>
__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
                                T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  T _weight = affine ? abs(weight[plane]) + eps : 1.f;
  T _bias = affine ? bias[plane] : 0.f;

  Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
  __syncthreads();

  if (threadIdx.x == 0) {
    edz[plane] = res.v1;
    eydz[plane] = res.v2;
  }
}
std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
                                      bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto edz = at::empty({chn}, z.options());
  auto eydz = at::empty({chn}, z.options());

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
    edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        z.data<scalar_t>(),
        dz.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        edz.data<scalar_t>(),
        eydz.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return {edz, eydz};
}
/***********
 * backward
 ***********/
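// Gradient w.r.t. the input, computed from z, dz and the reductions above:
//   dx = (dz - edz / count - y * eydz / count) * (|weight| + eps) * rsqrt(var + eps)
// with count = num * sp, i.e. the batch-norm input gradient expressed in terms of
// the recovered normalized activation y.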
template<typename T>
__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
                                const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  T _weight = affine ? abs(weight[plane]) + eps : 1.f;
  T _bias = affine ? bias[plane] : 0.f;
  T _var = var[plane];
  T _edz = edz[plane];
  T _eydz = eydz[plane];

  T _mul = _weight * rsqrt(_var + eps);
  T count = T(num * sp);

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      T _dz = dz[(batch * chn + plane) * sp + n];
      T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;

      dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
    }
  }
}
at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
                         at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);
  CHECK_CUDA_INPUT(edz);
  CHECK_CUDA_INPUT(eydz);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto dx = at::zeros_like(z);

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
    backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        z.data<scalar_t>(),
        dz.data<scalar_t>(),
        var.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        edz.data<scalar_t>(),
        eydz.data<scalar_t>(),
        dx.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return dx;
}
/**************
 * activations
 **************/
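// Both activation backward routines work purely in place: z holds the
// post-activation output and dz the incoming gradient. Each routine first
// rescales dz by the activation's derivative (using the sign of z as the mask)
// and then overwrites the negative entries of z with the corresponding
// pre-activation values (z / slope for leaky ReLU, log1p(z) for ELU),
// presumably so the same buffer can then be passed as `z` to edz_eydz_cuda /
// backward_cuda above without allocating extra storage.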
template<typename T>
inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
  // Create thrust pointers
  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);

  auto stream = at::cuda::getCurrentCUDAStream();
  // dz *= slope wherever the activation output was negative
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_dz, th_dz + count, th_z, th_dz,
                       [slope] __device__ (const T& dz) { return dz * slope; },
                       [] __device__ (const T& z) { return z < 0; });
  // Invert leaky ReLU in place: recover the pre-activation as z / slope for z < 0
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_z, th_z + count, th_z,
                       [slope] __device__ (const T& z) { return z / slope; },
                       [] __device__ (const T& z) { return z < 0; });
}
void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();

  AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
    leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
  }));
}
template<typename T>
inline void elu_backward_impl(T *z, T *dz, int64_t count) {
  // Create thrust pointers
  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);

  auto stream = at::cuda::getCurrentCUDAStream();
  // dz *= d(elu)/dx = z + 1 wherever the activation output was negative
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_dz, th_dz + count, th_z, th_z, th_dz,
                       [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
                       [] __device__ (const T& z) { return z < 0; });
  // Invert ELU in place: recover the pre-activation as log1p(z) for z < 0
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_z, th_z + count, th_z,
                       [] __device__ (const T& z) { return log1p(z); },
                       [] __device__ (const T& z) { return z < 0; });
}
void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();

  AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] {
    elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
  }));
}