|
#pragma once |
|
|
|
#include <ATen/core/Tensor.h> |
|
#include <ATen/Dispatch.h> |
|
#include <ATen/AccumulateType.h> |
|
#include <ATen/ceil_div.h> |
|
#include <ATen/cuda/CUDAContext.h> |
|
#include <ATen/cuda/DeviceUtils.cuh> |
|
#include <ATen/native/cuda/block_reduce.cuh> |
|
#include <ATen/native/cuda/DeviceSqrt.cuh> |
|
#include <ATen/native/cuda/LaunchUtils.h> |
|
#include <c10/macros/Macros.h> |
|
|
|
#ifndef AT_PER_OPERATOR_HEADERS |
|
#include <ATen/Functions.h> |
|
#else |
|
#include <ATen/ops/empty.h> |
|
#include <ATen/ops/empty_like.h> |
|
#include <ATen/ops/zeros.h> |
|
#endif |
|
|
|
namespace at { namespace native { |
|
|
|
|
|
#if defined(USE_ROCM) |
|
constexpr int MAX_BLOCK_SIZE = 256; |
|
#else |
|
constexpr int MAX_BLOCK_SIZE = 512; |
|
#endif |
|
|
|
constexpr unsigned MAX_GRID_SIZE = 65535u; |
|
|
|
|
|
static int getNumThreads(int nElem) { |
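  // Return the smallest block size from the table that still covers nElem,
  // falling back to MAX_BLOCK_SIZE for larger problem sizes.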
|
#if defined(USE_ROCM) |
|
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; |
|
#else |
|
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; |
|
#endif |
|
for (int i = 0; i != 5; ++i) { |
|
if (nElem <= threadSizes[i]) { |
|
return threadSizes[i]; |
|
} |
|
} |
|
return MAX_BLOCK_SIZE; |
|
} |
|
|
|
|
|
__device__ __forceinline__ int getMSB(int val) { |
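  // Index of the most significant set bit, i.e. floor(log2(val)) for val > 0.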
|
return 31 - __clz(val); |
|
} |
|
|
|
template <typename scalar_t, typename accscalar_t> |
|
struct Float2 { |
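  // Pair of accumulators, so two running sums (e.g. sum(dy) and
  // sum(dy * (x - mean))) can travel through a single reduction.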
|
accscalar_t v1, v2; |
|
__device__ Float2() {} |
|
__device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {} |
|
__device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {} |
|
__device__ Float2& operator+=(const Float2& a) { |
|
v1 += a.v1; |
|
v2 += a.v2; |
|
return *this; |
|
} |
|
__device__ friend Float2 operator+(Float2 a, const Float2& b) { |
|
a += b; |
|
return a; |
|
} |
|
}; |
|
|
|
template <typename scalar_t, typename accscalar_t, typename PTA> |
|
struct GradOp { |
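  // Element-wise functor for the backward reductions: returns
  // (grad_output, grad_output * (input - mean)) so that one pass over the
  // data yields both sum_dy and sum_dy_xmu.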
|
__device__ GradOp(accscalar_t m, const PTA& i, const PTA& g) |
|
: mean(m), input(i), grad_output(g) {} |
|
__device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) { |
|
accscalar_t g = grad_output[batch][plane][n]; |
|
accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean; |
|
return Float2<scalar_t, accscalar_t>(g, g * c); |
|
} |
|
const accscalar_t mean; |
|
const PTA& input; |
|
const PTA& grad_output; |
|
}; |
|
|
|
template <typename acc_t> |
|
struct SumReduceOp { |
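  // Plain sum, in the form expected by cuda_utils::BlockReduce; the Float2
  // specialization below shuffles both components.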
|
__device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } |
|
|
|
__device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { |
|
return WARP_SHFL_DOWN(data, offset); |
|
} |
|
}; |
|
|
|
template <typename scalar_t, typename accscalar_t> |
|
struct SumReduceOp<Float2<scalar_t, accscalar_t>> { |
|
using acc_t = Float2<scalar_t, accscalar_t>; |
|
|
|
__device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } |
|
|
|
__device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { |
|
return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)}; |
|
} |
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<typename scalar_t, typename Op, typename PTA> |
|
__device__ scalar_t reduce(Op op, PTA tensor, int plane) { |
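  // Sum op(batch, plane, x) over everything this block covers: each thread
  // first accumulates its strided share, BlockReduce then combines the
  // partials across the 2D block, and the result is broadcast to all
  // threads through shared[0].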
|
|
|
scalar_t sum = static_cast<scalar_t>(0); |
|
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { |
|
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { |
|
sum += op(batch, plane, x); |
|
} |
|
} |
|
__shared__ scalar_t shared[C10_WARP_SIZE]; |
|
SumReduceOp<scalar_t> reduce_op; |
|
sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared); |
|
if (threadIdx.x == 0 && threadIdx.y == 0) { |
|
shared[0] = sum; |
|
} |
|
__syncthreads(); |
|
|
|
return shared[0]; |
|
} |
|
|
|
constexpr int ELEMENTS_PER_ITER = 4; |
|
constexpr int ELEMENTS_PER_THREAD = 16; |
|
constexpr int OPTIMAL_TILE_W = 32; |
|
constexpr int MAX_H_BLOCK = 128; |
|
|
|
__host__ void flexible_launch_configs( |
|
const int reduction, |
|
const int stride, |
|
dim3 &block, |
|
dim3 &grid, |
|
const bool coop_flag = false) { |
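  // 2D launch for the channels-last kernels: block.x tiles the stride
  // (channel) dimension, block.y covers the reduction dimension at
  // ELEMENTS_PER_THREAD items per thread, and grid.y is capped at
  // MAX_H_BLOCK. With coop_flag set, a grid.y below 8 is collapsed to 1 so
  // the cross-block reduction path is skipped for small problems.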
|
int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W); |
|
int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)), |
|
MAX_BLOCK_SIZE / block_x); |
|
if (block_x * block_y != MAX_BLOCK_SIZE) { |
|
block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y); |
|
} |
|
|
|
int grid_x = at::ceil_div(stride, block_x); |
|
int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK); |
|
if (coop_flag) { |
|
|
|
grid_y = grid_y < 8 ? 1 : grid_y; |
|
} |
|
|
|
block.x = block_x; |
|
block.y = block_y; |
|
block.z = 1; |
|
grid.x = grid_x; |
|
grid.y = grid_y; |
|
grid.z = 1; |
|
} |
|
|
|
template<typename T, typename C> |
|
__device__ __forceinline__ void welford_merge_element(C& count, |
|
T& mean, |
|
T& m2n, |
|
const C& count_new, |
|
const T& mean_new, |
|
const T& m2n_new) { |
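  // Merge two Welford partial results (count, mean, M2) using the parallel
  // variance update.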
|
T factor = T(1.0) / ::max(1, (count + count_new)); |
|
T delta0 = mean - mean_new; |
|
mean = (mean_new * count_new + mean * count) * factor; |
|
m2n += m2n_new + delta0 * delta0 * count_new * count * factor; |
|
count += count_new; |
|
} |
|
|
|
|
|
template<typename T, typename C> |
|
__device__ __forceinline__ void welford_merge_block_vertical(C& count, |
|
T& mean, |
|
T& m2n, |
|
C* shmem_count, |
|
T* shmem_mean, |
|
T* shmem_m2n) { |
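  // Tree reduction along threadIdx.y: each round publishes the live partials
  // to shared memory, then the lower half of the rows folds in the upper
  // half's values.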
|
|
|
auto address_base = threadIdx.x + threadIdx.y * blockDim.x; |
|
|
|
#pragma unroll |
|
for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { |
|
if (threadIdx.y < offset*2) { |
|
shmem_mean[address_base] = mean; |
|
shmem_m2n[address_base] = m2n; |
|
shmem_count[address_base] = count; |
|
} |
|
__syncthreads(); |
|
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { |
|
auto address = address_base + offset * blockDim.x; |
|
|
|
auto count_new = shmem_count[address]; |
|
auto mean_new = shmem_mean[address]; |
|
auto m2n_new = shmem_m2n[address]; |
|
|
|
welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new); |
|
} |
|
} |
|
} |
|
|
|
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t> |
|
__global__ void batch_norm_transform_input_kernel( |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input, |
|
GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output, |
|
const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_, |
|
const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd, |
|
const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> weight, |
|
const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> bias, |
|
stat_accscalar_t epsilon) { |
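  // Forward elementwise pass, one blockIdx.x per plane (channel). In
  // training mode var_or_invstd already holds invstd; in eval mode it holds
  // the running variance and invstd is recomputed from it and epsilon.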
|
|
|
index_t plane = blockIdx.x; |
|
|
|
if (plane >= input.size(1)) { |
|
return; |
|
} |
|
|
|
stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1); |
|
stat_accscalar_t beta = bias.size(0) > 0 ? static_cast<stat_accscalar_t>(bias[plane]) : static_cast<stat_accscalar_t>(0); |
|
stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]); |
|
stat_accscalar_t invstd; |
|
if (train) { |
|
invstd = var_or_invstd[plane]; |
|
} else { |
|
invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(var_or_invstd[plane]) + epsilon); |
|
} |
|
|
|
index_t bs = input.size(0); |
|
index_t fs = input.size(2); |
|
|
|
index_t bstep = blockDim.y * gridDim.y; |
|
for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { |
|
auto o = output[batch][plane]; |
|
auto i = input[batch][plane]; |
|
for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { |
|
o[feature] = static_cast<input_scalar_t>(gamma * (i[feature] - mean) * invstd + beta); |
|
} |
|
} |
|
} |
|
|
|
struct InvStd { |
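  // Transforms applied to the biased variance before it is stored:
  // InvStd maps it to 1 / sqrt(var + epsilon) (0 if both are zero),
  // Var below stores it unchanged.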
|
template <typename T> |
|
__device__ __forceinline__ T operator()(T var, double epsilon) const { |
|
T invstd = 0; |
|
if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) { |
|
invstd = static_cast<T>(1) / device_sqrt(var + epsilon); |
|
} |
|
return invstd; |
|
} |
|
}; |
|
|
|
struct Var { |
|
template <typename T> |
|
__device__ __forceinline__ T operator()(T var, double epsilon) const { |
|
return var; |
|
} |
|
}; |
|
|
|
template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t> |
|
__global__ void batch_norm_collect_statistics_kernel( |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input, |
|
const stat_accscalar_t epsilon, |
|
const stat_accscalar_t momentum, |
|
GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean, |
|
GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) { |
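  // One block per plane. Each thread runs Welford's online update over its
  // strided slice of (batch, feature), partials are merged with warp
  // shuffles, the per-warp results go through shared memory, and a second
  // shuffle pass over the first warp produces the plane's mean and
  // (transformed) variance.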
|
|
|
__shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE]; |
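  // The first C10_WARP_SIZE ints hold the per-warp counts; the tail of the
  // buffer is reinterpreted below as interleaved (avg, var_n) pairs of
  // stat_accscalar_t, hence the extra 2 * 2 * C10_WARP_SIZE ints.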
|
|
|
int plane = blockIdx.x; |
|
int N = input.size(0) * input.size(2); |
|
int tid = threadIdx.x + threadIdx.y * blockDim.x; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE]; |
|
|
|
|
|
stat_accscalar_t avg = 0; |
|
stat_accscalar_t var_n = 0; |
|
int n = 0; |
|
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { |
|
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { |
|
stat_accscalar_t v = input[batch][plane][x]; |
|
stat_accscalar_t d1 = v - avg; |
|
n++; |
|
avg += d1 / n; |
|
var_n += d1 * (v - avg); |
|
} |
|
} |
|
|
|
|
|
|
|
for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { |
|
stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); |
|
int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); |
|
stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); |
|
var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; |
|
avg = (n * avg + o_n * o_avg) * factor; |
|
n += o_n; |
|
} |
|
|
|
|
|
|
|
|
|
__syncthreads(); |
|
if (tid % C10_WARP_SIZE == 0) { |
|
shared_n[tid / C10_WARP_SIZE] = n; |
|
shared_avg_var[tid / C10_WARP_SIZE * 2] = avg; |
|
shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n; |
|
} |
|
__syncthreads(); |
|
|
|
|
|
|
|
|
|
if (tid < C10_WARP_SIZE) { |
|
n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0); |
|
avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0)); |
|
var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0)); |
|
} |
|
for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { |
|
stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); |
|
int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); |
|
stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); |
|
var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; |
|
avg = (n * avg + o_n * o_avg) * factor; |
|
n += o_n; |
|
} |
|
|
|
|
|
if (tid == 0) { |
|
if (save_mean.data() != NULL) { |
|
save_mean[plane] = avg; |
|
} |
|
if (save_transformed_var.data() != NULL) { |
|
save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon); |
|
} |
|
} |
|
|
|
} |
|
|
|
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t> |
|
__global__ void batch_norm_backward_kernel( |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input, |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output, |
|
GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input, |
|
GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight, |
|
GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias, |
|
const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight, |
|
const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_mean, |
|
const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_var, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_mean, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd, |
|
bool train, |
|
stat_accscalar_t epsilon) { |
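  // One block per plane: a single fused reduction produces sum(dy) and
  // sum(dy * (x - mean)); grad_input, grad_weight and grad_bias are all
  // derived from those two values.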
|
|
|
index_t plane = blockIdx.x; |
|
index_t N = grad_output.size(0) * grad_output.size(2); |
|
|
|
stat_accscalar_t mean, invstd; |
|
if (train) { |
|
mean = save_mean[plane]; |
|
invstd = save_invstd[plane]; |
|
} else { |
|
mean = static_cast<stat_accscalar_t>(running_mean[plane]); |
|
invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(running_var[plane]) + epsilon); |
|
} |
|
|
|
stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1); |
|
stat_accscalar_t norm = stat_accscalar_t(1) / N; |
|
|
|
|
|
|
|
|
|
GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output); |
|
auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane); |
|
|
|
stat_accscalar_t grad_output_sum = res.v1; |
|
stat_accscalar_t dot_p = res.v2; |
|
|
|
stat_accscalar_t grad_mean = grad_output_sum * norm; |
|
stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd; |
|
stat_accscalar_t grad_scale = invstd * weight_val; |
|
|
|
if (grad_input.data() != NULL) { |
|
for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) { |
|
for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) { |
|
input_scalar_t go = grad_output[batch][plane][x]; |
|
if (train) { |
|
stat_accscalar_t inp = input[batch][plane][x]; |
|
stat_accscalar_t proj = (inp - mean) * proj_scale; |
|
grad_input[batch][plane][x] = static_cast<input_scalar_t>((go - proj - grad_mean) * grad_scale); |
|
} else { |
|
grad_input[batch][plane][x] = static_cast<input_scalar_t>(go * grad_scale); |
|
} |
|
} |
|
} |
|
} |
|
|
|
if (grad_weight.size(0) > 0) { |
|
if (threadIdx.x == 0) { |
|
grad_weight[plane] = static_cast<stat_scalar_t>(dot_p * invstd); |
|
} |
|
} |
|
|
|
if (grad_bias.size(0) > 0) { |
|
if (threadIdx.x == 0) { |
|
grad_bias[plane] = static_cast<stat_scalar_t>(grad_output_sum); |
|
} |
|
} |
|
} |
|
|
|
template <typename scalar_t, typename accscalar_t, typename index_t> |
|
__global__ void batch_norm_reduce_statistics_kernel( |
|
const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean, |
|
const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd, |
|
GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean, |
|
GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd, |
|
GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean, |
|
GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var, |
|
const accscalar_t epsilon, |
|
const accscalar_t momentum, |
|
const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> counts) { |
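  // Combine statistics gathered from several replicas (see
  // batch_norm_gather_stats_cuda_template): invstd is converted back to a
  // biased variance, the replicas are merged with the parallel Welford
  // update, and the running stats are updated with the unbiased variance.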
|
|
|
int feature_size = vec_mean.size(1); |
|
int world_size = vec_mean.size(0); |
|
|
|
int bid = blockIdx.x; |
|
int tid = threadIdx.x; |
|
|
|
|
|
for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) { |
|
accscalar_t avg = 0; |
|
accscalar_t var_n = 0; |
|
index_t n = 0; |
|
for (int j = 0; j < world_size; j++) { |
|
scalar_t count = counts[j]; |
|
accscalar_t m = vec_mean[j][i]; |
|
accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]); |
|
v = (v * v - epsilon) * count; |
|
accscalar_t factor = 1.0 / (n + count); |
|
var_n += v + (avg - m) * (avg - m) * n * count * factor; |
|
avg = n * factor * avg + count * factor * m; |
|
n += count; |
|
} |
|
mean[i] = avg; |
|
invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon); |
|
if (running_mean.data() != NULL) { |
|
running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg); |
|
} |
|
accscalar_t unbiasedVar = var_n / (n - 1); |
|
if (running_var.data() != NULL) { |
|
running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar); |
|
} |
|
} |
|
|
|
} |
|
|
|
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t> |
|
__global__ void batch_norm_backward_reduce_kernel( |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input, |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output, |
|
GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean, |
|
GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd, |
|
GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy, |
|
GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu, |
|
GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight, |
|
GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) { |
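  // Per-plane reduction only: writes sum_dy, sum_dy_xmu and, when requested,
  // grad_weight / grad_bias; the elementwise grad_input pass runs in a
  // separate kernel.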
|
|
|
index_t plane = blockIdx.x; |
|
|
|
stat_accscalar_t r_mean = mean[plane]; |
|
stat_accscalar_t factor = invstd[plane]; |
|
|
|
GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output); |
|
auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane); |
|
|
|
if (threadIdx.x == 0) { |
|
if (grad_weight.size(0) > 0) { |
|
grad_weight[plane] = static_cast<stat_scalar_t>(res.v2 * factor); |
|
} |
|
if (grad_bias.size(0) > 0) { |
|
grad_bias[plane] = static_cast<stat_scalar_t>(res.v1); |
|
} |
|
if (sum_dy.size(0) > 0) { |
|
sum_dy[plane] = static_cast<stat_accscalar_t>(res.v1); |
|
} |
|
if (sum_dy_xmu.size(0) > 0) { |
|
sum_dy_xmu[plane] = static_cast<stat_accscalar_t>(res.v2); |
|
} |
|
} |
|
} |
|
|
|
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t> |
|
__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl( |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input, |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd, |
|
const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu, |
|
GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input, |
|
const stat_accscalar_t norm_fct) { |
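  // Elementwise backward:
  //   grad_input = (dy - mean_dy - (x - mean) * invstd^2 * mean_dy_xmu)
  //                * invstd * weight
  // where mean_dy and mean_dy_xmu come from scaling the reduced sums by
  // norm_fct (the reciprocal of the reduction size).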
|
index_t plane = blockIdx.x; |
|
|
|
if (plane >= input.size(1)) { |
|
return; |
|
} |
|
|
|
stat_accscalar_t m_c = mean[plane]; |
|
stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct; |
|
stat_accscalar_t factor_1_c = invstd[plane]; |
|
stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1); |
|
factor_2_c *= factor_1_c; |
|
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct; |
|
|
|
index_t bs = input.size(0); |
|
index_t fs = input.size(2); |
|
|
|
index_t bstep = blockDim.y * gridDim.y; |
|
for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { |
|
auto g_i = grad_input[batch][plane]; |
|
auto g_o = grad_output[batch][plane]; |
|
auto i = input[batch][plane]; |
|
for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { |
|
g_i[feature] = static_cast<input_scalar_t>((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c); |
|
} |
|
} |
|
} |
|
|
|
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t> |
|
__global__ void batch_norm_backward_elemt_kernel( |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input, |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd, |
|
const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu, |
|
GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input, |
|
const int* __restrict__ numel, const int world_size) { |
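  // Variant used with distributed counts: the per-replica element counts
  // live on the device, so the normalization factor is computed here from
  // their sum.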
|
int64_t total_numel = 0; |
|
for (int i = 0; i < world_size; i ++) { |
|
total_numel += numel[i]; |
|
} |
|
|
|
const stat_accscalar_t norm_fct = |
|
static_cast<stat_accscalar_t>(1) / static_cast<stat_accscalar_t>(total_numel); |
|
batch_norm_backward_elemt_kernel_impl( |
|
input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); |
|
} |
|
|
|
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t> |
|
__global__ void batch_norm_backward_elemt_kernel( |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input, |
|
const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd, |
|
const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy, |
|
const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu, |
|
GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input, |
|
const stat_accscalar_t norm_fct) { |
|
batch_norm_backward_elemt_kernel_impl( |
|
input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); |
|
} |
|
|
|
template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> |
|
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor( |
|
const Tensor& t, c10::string_view var_name) { |
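  // Checked accessor construction: verify the dtype first so that a mismatch
  // surfaces as a readable error naming the offending tensor.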
|
constexpr auto expect_type = c10::CppTypeToScalarType<scalar_t>::value; |
|
const auto actual_type = t.scalar_type(); |
|
TORCH_CHECK(actual_type == expect_type, "Expected ", var_name, |
|
" to have type ", expect_type, " but got ", actual_type); |
|
return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>(); |
|
} |
|
|
|
template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> |
|
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy( |
|
const Tensor& t, c10::string_view var_name) { |
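  // Undefined tensors get a zero-sized dummy accessor, so kernels can simply
  // test size(0) or data() instead of branching on defined() host-side.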
|
if (!t.defined()) { |
|
const std::array<index_t, dim> zeros{{0}}; |
|
return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data()); |
|
} |
|
return get_packed_accessor<scalar_t, dim, PtrTraits, index_t>(t, var_name); |
|
} |
|
|
|
template<typename input_scalar_t, typename stat_scalar_t, typename index_t> |
|
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_, |
|
const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_, |
|
bool train, double epsilon, std::array<bool,3> grad_input_mask) { |
|
|
|
using accscalar_t = at::acc_type<stat_scalar_t, true>; |
|
Tensor grad_input_; |
|
Tensor grad_input_reshaped; |
|
Tensor grad_weight_; |
|
Tensor grad_bias_; |
|
auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); |
|
auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); |
|
|
|
if (grad_input_mask[0]) { |
|
grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
grad_input_reshaped = grad_input_.view(input_reshaped.sizes()); |
|
} |
|
if (grad_input_mask[1]) { |
|
grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
} |
|
if (grad_input_mask[2]) { |
|
grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
} |
|
|
|
auto input = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); |
|
auto grad_output = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); |
|
auto grad_input = packed_accessor_or_dummy< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); |
|
auto weight = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); |
|
auto grad_weight = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); |
|
auto grad_bias = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); |
|
auto running_mean = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean"); |
|
auto running_var = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var"); |
|
auto save_mean = packed_accessor_or_dummy< |
|
accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean"); |
|
auto save_invstd = packed_accessor_or_dummy< |
|
accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd"); |
|
|
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
dim3 blocks(input.size(1)); |
|
int tf = getNumThreads(input.size(2)); |
|
dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf)); |
|
|
|
batch_norm_backward_kernel<input_scalar_t, stat_scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>> |
|
(input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var, |
|
save_mean, save_invstd, train, epsilon); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
|
|
return std::make_tuple(grad_input_, grad_weight_, grad_bias_); |
|
} |
|
|
|
template<typename scalar_t, typename index_t, typename VarTransform> |
|
void batch_norm_stats_cuda_template( |
|
const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) { |
|
|
|
using accscalar_t = at::acc_type<scalar_t, true>; |
|
int64_t n_input = input_.size(1); |
|
Tensor dummy_mean_; |
|
Tensor dummy_var_; |
|
auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); |
|
|
|
resize_output(out_mean, {n_input}); |
|
resize_output(out_invstd, {n_input}); |
|
auto input = get_packed_accessor< |
|
scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); |
|
TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && |
|
out_invstd.sizes()[0]); |
|
TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && |
|
out_mean.sizes()[0]); |
|
|
|
auto mean = packed_accessor_or_dummy< |
|
accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean"); |
|
auto invstd = packed_accessor_or_dummy< |
|
accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd"); |
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
|
|
dim3 blocks(input.size(1)); |
|
int tf = getNumThreads(input.size(2)); |
|
dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf)); |
|
batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>> |
|
(input, epsilon, 0.0, mean, invstd); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
|
|
|
template<typename input_scalar_t, typename stat_scalar_t, typename index_t> |
|
void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_, |
|
const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) { |
|
|
|
using stat_accscalar_t = at::acc_type<stat_scalar_t, true>; |
|
int64_t n_input = input_.size(1); |
|
auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); |
|
auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1}); |
|
|
|
auto input = get_packed_accessor< |
|
input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); |
|
auto output = get_packed_accessor< |
|
input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output"); |
|
auto weight = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight"); |
|
auto bias = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias"); |
|
auto mean = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean"); |
|
auto invstd = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd"); |
|
auto stream = at::cuda::getCurrentCUDAStream(); |
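  // The transform kernel below is instantiated with train=true, so it reads
  // invstd directly and never touches epsilon; dummy_epsilon merely fills
  // the argument slot.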
|
|
|
|
|
const double dummy_epsilon = 1e-5; |
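  // Launch configuration: tf threads cover the feature dimension, tb = 64/tf
  // batches per block; planes map to blocks_trans.x and the batch dimension
  // is tiled over blocks_trans.y, capped at MAX_GRID_SIZE.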
|
|
|
|
|
|
|
|
|
|
|
int tf = std::max<int>(getNumThreads(input.size(2)/4), |
|
std::min<int>(getNumThreads(input.size(2)), 64)); |
|
int tb = std::max<int>(64/tf, 1); |
|
dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1), |
|
(input.size(0)+tb-1)/tb))); |
|
blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); |
|
dim3 threads_trans(tf, tb); |
|
batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>> |
|
(input, output, mean, invstd, weight, bias, dummy_epsilon); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
|
|
|
template<typename scalar_t, typename accscalar_t, typename index_t> |
|
std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_, |
|
const Tensor& running_mean_, const Tensor& running_var_, |
|
double momentum, double epsilon, const Tensor& counts_) { |
|
|
|
Tensor save_mean_; |
|
Tensor save_invstd_; |
|
|
|
auto features = mean_.size(1); |
|
auto input_options = mean_.options(); |
|
if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) { |
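    // Statistics are accumulated in float when the inputs are half/bfloat16.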
|
input_options = input_options.dtype(ScalarType::Float); |
|
} |
|
save_mean_ = at::empty({features}, input_options); |
|
save_invstd_ = at::empty({features}, input_options); |
|
|
|
auto mean = packed_accessor_or_dummy< |
|
accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean"); |
|
auto invstd = packed_accessor_or_dummy< |
|
accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd"); |
|
auto running_mean = packed_accessor_or_dummy< |
|
scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean"); |
|
auto running_var = packed_accessor_or_dummy< |
|
      scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_var");
|
auto counts = packed_accessor_or_dummy< |
|
scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts"); |
|
|
|
auto save_mean = get_packed_accessor< |
|
accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean"); |
|
auto save_invstd = get_packed_accessor< |
|
accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd"); |
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
|
|
int block = getNumThreads(features); |
|
int grid = std::max<int>(1, features/block); |
|
batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>> |
|
(mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
|
|
return std::make_tuple(save_mean_, save_invstd_); |
|
} |
|
|
|
template<typename input_scalar_t, typename stat_scalar_t, typename index_t> |
|
std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_, |
|
const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_, |
|
const bool input_g, const bool weight_g, const bool bias_g) { |
|
|
|
using stat_accscalar_t = at::acc_type<stat_scalar_t, true>; |
|
int64_t n_input = input_.size(1); |
|
Tensor sum_dy_; |
|
Tensor sum_dy_xmu_; |
|
Tensor grad_weight_; |
|
Tensor grad_bias_; |
|
auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); |
|
auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); |
|
|
|
if (input_g) { |
|
sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
} |
|
if (weight_g) { |
|
grad_weight_ = at::empty({n_input}, weight_.options()); |
|
} |
|
if (bias_g) { |
|
grad_bias_ = at::empty({n_input}, weight_.options()); |
|
} |
|
|
|
auto input = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); |
|
auto grad_output = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); |
|
auto grad_weight = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); |
|
auto grad_bias = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); |
|
auto mean = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); |
|
auto invstd = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); |
|
auto sum_dy = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); |
|
auto sum_dy_xmu = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); |
|
|
|
auto batch_size = input_reshaped.size(0); |
|
auto feature_size = input_reshaped.size(2); |
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
|
|
int warp_size = at::cuda::warp_size(); |
|
int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size); |
|
|
|
int block_x = std::min<int>(std::max<int>(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y); |
|
const dim3 block(block_x, block_y); |
|
const dim3 grid(n_input); |
|
|
|
batch_norm_backward_reduce_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<grid, block, 0, stream>>> |
|
(input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
|
|
return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_); |
|
} |
|
|
|
template<typename input_scalar_t, typename stat_scalar_t, typename index_t> |
|
Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, |
|
const Tensor& mean_, const Tensor& invstd_, |
|
const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) { |
|
|
|
using stat_accscalar_t = at::acc_type<stat_scalar_t, true>; |
|
int64_t n_input = input_.size(1); |
|
auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); |
|
auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); |
|
auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
|
|
auto input = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); |
|
auto grad_input = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); |
|
auto grad_output = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); |
|
auto mean = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); |
|
auto invstd = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); |
|
auto weight = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); |
|
auto sum_dy = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); |
|
auto sum_dy_xmu = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); |
|
|
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
|
|
|
|
|
|
|
|
|
|
int tf = std::max<int>(getNumThreads(input.size(2)/4), |
|
std::min<int>(getNumThreads(input.size(2)), 64)); |
|
int tb = std::max<int>(64/tf, 1); |
|
dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1), |
|
(input.size(0)+tb-1)/tb))); |
|
blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); |
|
dim3 threads_trans(tf, tb); |
|
auto reduction_size = input_.numel() / n_input; |
|
auto norm_fct = static_cast<stat_accscalar_t>(1.0 / reduction_size); |
|
batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> |
|
<<<blocks_trans, threads_trans, 0, stream>>> |
|
(input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
|
|
return grad_input_reshaped.view(input_.sizes()); |
|
} |
|
|
|
template<typename input_scalar_t, typename stat_scalar_t, typename index_t> |
|
Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, |
|
const Tensor& mean_, const Tensor& invstd_, |
|
const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) { |
|
|
|
using stat_accscalar_t = at::acc_type<stat_scalar_t, true>; |
|
int64_t n_input = input_.size(1); |
|
auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); |
|
auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); |
|
auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
|
|
auto input = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); |
|
auto grad_input = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); |
|
auto grad_output = get_packed_accessor< |
|
input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); |
|
auto mean = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); |
|
auto invstd = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); |
|
auto weight = packed_accessor_or_dummy< |
|
stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); |
|
auto sum_dy = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); |
|
auto sum_dy_xmu = packed_accessor_or_dummy< |
|
stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); |
|
|
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
|
|
|
|
|
|
|
|
|
|
int tf = std::max<int>(getNumThreads(input.size(2)/4), |
|
std::min<int>(getNumThreads(input.size(2)), 64)); |
|
int tb = std::max<int>(64/tf, 1); |
|
dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1), |
|
(input.size(0)+tb-1)/tb))); |
|
blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); |
|
dim3 threads_trans(tf, tb); |
|
batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>> |
|
(input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.data_ptr<int>(), count.numel()); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
|
|
return grad_input_reshaped.view(input_.sizes()); |
|
} |
|
|
|
|
|
|
|
template |
|
<typename VarTransform, |
|
typename scalar_t, |
|
typename accscalar_t, |
|
int PARALLEL_LOADS> |
|
__global__ void |
|
batch_norm_collect_statistics_channels_last_kernel( |
|
const scalar_t* __restrict__ input, |
|
accscalar_t* __restrict__ out_mean, |
|
accscalar_t* __restrict__ out_invstd, |
|
volatile accscalar_t* staging_data, |
|
int* semaphores, |
|
const int reduction_size, |
|
const int stride, |
|
accscalar_t epsilon) { |
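  // Channels-last statistics: threads tile the (reduction, channel) plane
  // and keep PARALLEL_LOADS interleaved Welford partials to hide load
  // latency. Partials are merged within the block through shared memory;
  // when gridDim.y > 1 the per-block results are combined through the
  // staging buffer, and the last block to bump the semaphore performs the
  // final merge and writes out_mean / out_invstd.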
|
|
|
accscalar_t x_mean[PARALLEL_LOADS]; |
|
accscalar_t m_2_n[PARALLEL_LOADS]; |
|
int count[PARALLEL_LOADS]; |
|
|
|
#pragma unroll |
|
for (int i = 0; i < PARALLEL_LOADS; i++) { |
|
x_mean[i] = accscalar_t(0); |
|
m_2_n[i] = accscalar_t(0); |
|
    count[i] = 0;
|
} |
|
|
|
|
|
|
|
int inner_loop_stride = blockDim.y * gridDim.y; |
|
|
|
|
|
int m_offset = blockIdx.y * blockDim.y + threadIdx.y; |
|
int c_offset = blockIdx.x * blockDim.x + threadIdx.x; |
|
|
|
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); |
|
int address_base = m_offset * stride + c_offset; |
|
int address_increment = inner_loop_stride * stride; |
|
|
|
for (int i = 0; i < loop_count; i++) { |
|
accscalar_t x_math[PARALLEL_LOADS]; |
|
accscalar_t x_count_inv[PARALLEL_LOADS]; |
|
accscalar_t is_valid[PARALLEL_LOADS]; |
|
|
|
|
|
#pragma unroll |
|
for (int j = 0; j < PARALLEL_LOADS; j++) { |
|
if (c_offset < stride && m_offset < reduction_size) { |
|
x_math[j] = input[address_base]; |
|
count[j]++; |
|
x_count_inv[j] = accscalar_t(1) / count[j]; |
|
is_valid[j] = accscalar_t(1); |
|
} else { |
|
x_math[j] = accscalar_t(0); |
|
x_count_inv[j] = accscalar_t(0); |
|
is_valid[j] = accscalar_t(0); |
|
} |
|
m_offset += inner_loop_stride; |
|
address_base += address_increment; |
|
} |
|
|
|
|
|
#pragma unroll |
|
for (int j = 0; j < PARALLEL_LOADS; j++) { |
|
accscalar_t delta0 = x_math[j] - x_mean[j]; |
|
x_mean[j] += delta0 * x_count_inv[j]; |
|
accscalar_t delta1 = x_math[j] - x_mean[j]; |
|
m_2_n[j] += delta0 * delta1 * is_valid[j]; |
|
} |
|
} |
|
|
|
|
|
#pragma unroll |
|
for (int j = 1; j < PARALLEL_LOADS; j++) { |
|
welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]); |
|
} |
|
|
|
|
|
auto mean_th = x_mean[0]; |
|
auto m2_th = m_2_n[0]; |
|
auto count_th = count[0]; |
|
|
|
|
|
static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE]; |
|
static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE]; |
|
static __shared__ int shmem_count[MAX_BLOCK_SIZE]; |
|
|
|
welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); |
|
|
|
if (gridDim.y > 1) { |
|
volatile accscalar_t* staging_mean = staging_data; |
|
volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y]; |
|
volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]); |
|
|
|
address_base = c_offset + blockIdx.y * stride; |
|
|
|
if (threadIdx.y == 0 && c_offset < stride) { |
|
staging_mean[address_base] = mean_th; |
|
staging_m2n[address_base] = m2_th; |
|
staging_count[address_base] = count_th; |
|
} |
|
|
|
__threadfence(); |
|
__syncthreads(); |
|
|
|
__shared__ bool is_last_block_done; |
|
|
|
if (threadIdx.x == 0 && threadIdx.y == 0) { |
|
int old = atomicAdd(&semaphores[blockIdx.x], 1); |
|
is_last_block_done = (old == (gridDim.y-1)); |
|
} |
|
|
|
__syncthreads(); |
|
|
|
|
|
if (is_last_block_done) { |
|
count_th = 0; |
|
mean_th = accscalar_t(0.0); |
|
m2_th = accscalar_t(0.0); |
|
|
|
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { |
|
address_base = c_offset + y * stride; |
|
int count_new = c_offset < stride ? staging_count[address_base] : 0; |
|
accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0); |
|
accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0); |
|
|
|
welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new); |
|
} |
|
|
|
welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); |
|
if (threadIdx.y == 0 && c_offset < stride) { |
|
out_mean[c_offset] = static_cast<accscalar_t>(mean_th); |
|
out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); |
|
} |
|
} |
|
} else { |
|
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { |
|
out_mean[c_offset] = static_cast<accscalar_t>(mean_th); |
|
out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
template < |
|
typename scalar_t, |
|
typename accscalar_t, |
|
typename layerscalar_t, |
|
int PARALLEL_LOADS> |
|
__global__ void batch_norm_transform_input_channels_last_kernel( |
|
const scalar_t* __restrict__ input, |
|
const scalar_t* __restrict__ z, |
|
const accscalar_t* __restrict__ mean, |
|
const accscalar_t* __restrict__ inv_std, |
|
const layerscalar_t* __restrict__ weight, |
|
const layerscalar_t* __restrict__ shift, |
|
scalar_t* __restrict__ out, |
|
const int reduction_size, |
|
const int stride, |
|
const bool fuse_relu) { |
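  // Channels-last forward elementwise pass: out = gamma * (x - mean) * invstd
  // + beta, optionally adding the residual z and applying ReLU when
  // fuse_relu is set.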
|
|
|
|
|
int inner_loop_stride = blockDim.y * gridDim.y; |
|
|
|
|
|
int m_offset = blockIdx.y * blockDim.y + threadIdx.y; |
|
int c_offset = blockIdx.x * blockDim.x + threadIdx.x; |
|
|
|
if (c_offset >= stride || m_offset >= reduction_size) { |
|
return; |
|
} |
|
|
|
auto m_c = mean[c_offset]; |
|
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]); |
|
auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]); |
|
auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]); |
|
|
|
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); |
|
int address_base = m_offset * stride + c_offset; |
|
int address_increment = inner_loop_stride * stride; |
|
|
|
for (int i = 0; i < loop_count; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < PARALLEL_LOADS; j++) { |
|
if (c_offset < stride && m_offset < reduction_size) { |
|
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c; |
|
if (z != nullptr) { |
|
tmp += z[address_base]; |
|
} |
|
out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp)); |
|
} |
|
m_offset += inner_loop_stride; |
|
address_base += address_increment; |
|
} |
|
} |
|
} |
|
|
|
template<typename T> |
|
__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy, |
|
T& sum_dy_xmu, |
|
T* shmem_sum_dy, |
|
T* shmem_sum_dy_xmu) { |
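  // Same tree reduction along threadIdx.y as welford_merge_block_vertical,
  // but for the plain (sum_dy, sum_dy_xmu) pair.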
|
|
|
auto address_base = threadIdx.x + threadIdx.y * blockDim.x; |
|
|
|
#pragma unroll |
|
for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { |
|
if (threadIdx.y < offset*2) { |
|
shmem_sum_dy[address_base] = sum_dy; |
|
shmem_sum_dy_xmu[address_base] = sum_dy_xmu; |
|
} |
|
__syncthreads(); |
|
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { |
|
auto address = address_base + offset * blockDim.x; |
|
|
|
sum_dy += shmem_sum_dy[address]; |
|
sum_dy_xmu += shmem_sum_dy_xmu[address]; |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
template < |
|
int PARALLEL_LOADS, |
|
typename scalar_t, |
|
typename accscalar_t, |
|
typename layerscalar_t> |
|
__global__ void batch_norm_backward_reduce_channels_last_kernel( |
|
const scalar_t* __restrict__ input, |
|
const scalar_t* __restrict__ grad_output, |
|
const accscalar_t* __restrict__ mean, |
|
const accscalar_t* __restrict__ inv_std, |
|
accscalar_t* __restrict__ sum_dy_o, |
|
accscalar_t* __restrict__ sum_dy_xmu_o, |
|
layerscalar_t* __restrict__ grad_weight, |
|
layerscalar_t* __restrict__ grad_bias, |
|
volatile accscalar_t* staging_data, |
|
int* semaphores, |
|
const int reduction_size, |
|
const int stride) { |
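  // Channels-last backward reduction: accumulates sum_dy and sum_dy_xmu per
  // channel with the same tiling / staging / semaphore scheme as the
  // statistics kernel above, then derives grad_weight and grad_bias from
  // them.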
|
|
|
|
|
accscalar_t sum_dy[PARALLEL_LOADS]; |
|
accscalar_t sum_dy_xmu[PARALLEL_LOADS]; |
|
|
|
#pragma unroll |
|
for (int i = 0; i < PARALLEL_LOADS; i++) { |
|
sum_dy[i] = accscalar_t(0); |
|
sum_dy_xmu[i] = accscalar_t(0); |
|
} |
|
|
|
|
|
|
|
int inner_loop_stride = blockDim.y * gridDim.y; |
|
|
|
|
|
int m_offset = blockIdx.y * blockDim.y + threadIdx.y; |
|
int c_offset = blockIdx.x * blockDim.x + threadIdx.x; |
|
|
|
if (c_offset >= stride || m_offset >= reduction_size) { |
|
return; |
|
} |
|
|
|
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); |
|
int address_base = m_offset * stride + c_offset; |
|
int address_increment = inner_loop_stride * stride; |
|
|
|
auto r_mean = mean[c_offset]; |
|
auto factor = inv_std[c_offset]; |
|
|
|
for (int i = 0; i < loop_count; i++) { |
|
accscalar_t x_input[PARALLEL_LOADS]; |
|
accscalar_t x_grad_output[PARALLEL_LOADS]; |
|
|
|
|
|
#pragma unroll |
|
for (int j = 0; j < PARALLEL_LOADS; j++) { |
|
if (c_offset < stride && m_offset < reduction_size) { |
|
x_input[j] = input[address_base]; |
|
x_grad_output[j] = grad_output[address_base]; |
|
} else { |
|
x_input[j] = accscalar_t(0); |
|
x_grad_output[j] = accscalar_t(0); |
|
} |
|
m_offset += inner_loop_stride; |
|
address_base += address_increment; |
|
} |
|
|
|
|
|
#pragma unroll |
|
for (int j = 0; j < PARALLEL_LOADS; j++) { |
|
sum_dy[j] += x_grad_output[j]; |
|
sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean); |
|
} |
|
} |
|
|
|
|
|
#pragma unroll |
|
for (int j = 1; j < PARALLEL_LOADS; j++) { |
|
sum_dy[0] += sum_dy[j]; |
|
sum_dy_xmu[0] += sum_dy_xmu[j]; |
|
} |
|
|
|
|
|
auto sum_dy_th = sum_dy[0]; |
|
auto sum_dy_xmu_th = sum_dy_xmu[0]; |
|
|
|
|
|
static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE]; |
|
static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE]; |
|
|
|
merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); |
|
|
|
if (gridDim.y > 1) { |
|
volatile accscalar_t* staging_sum_dy = staging_data; |
|
volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y]; |
|
|
|
address_base = c_offset + blockIdx.y * stride; |
|
|
|
if (threadIdx.y == 0 && c_offset < stride) { |
|
staging_sum_dy[address_base] = sum_dy_th; |
|
staging_sum_dy_xmu[address_base] = sum_dy_xmu_th; |
|
} |
|
|
|
__threadfence(); |
|
__syncthreads(); |
|
|
|
__shared__ bool is_last_block_done; |
|
|
|
if (threadIdx.x == 0 && threadIdx.y == 0) { |
|
int old = atomicAdd(&semaphores[blockIdx.x], 1); |
|
is_last_block_done = (old == (gridDim.y-1)); |
|
} |
|
|
|
__syncthreads(); |
|
|
|
|
|
if (is_last_block_done) { |
|
sum_dy_th = accscalar_t(0.0); |
|
sum_dy_xmu_th = accscalar_t(0.0); |
|
|
|
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { |
|
address_base = c_offset + y * stride; |
|
sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0)); |
|
sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0)); |
|
} |
|
|
|
merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); |
|
if (threadIdx.y == 0 && c_offset < stride) { |
|
if (grad_bias != nullptr) { |
|
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th); |
|
} |
|
if (grad_weight != nullptr) { |
|
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor); |
|
} |
|
|
|
|
|
sum_dy_o[c_offset] = sum_dy_th; |
|
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; |
|
} |
|
} |
|
} else { |
|
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { |
|
if (grad_bias != nullptr) { |
|
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th); |
|
} |
|
if (grad_weight != nullptr) { |
|
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor); |
|
} |
|
|
|
|
|
sum_dy_o[c_offset] = sum_dy_th; |
|
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
template < |
|
int PARALLEL_LOADS, |
|
typename scalar_t, |
|
typename accscalar_t, |
|
typename layerscalar_t> |
|
__device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl( |
|
const scalar_t* __restrict__ grad_output, |
|
const scalar_t* __restrict__ input, |
|
const accscalar_t* __restrict__ mean, |
|
const accscalar_t* __restrict__ inv_std, |
|
const layerscalar_t* __restrict__ weight, |
|
const accscalar_t* __restrict__ sum_dy, |
|
const accscalar_t* __restrict__ sum_dy_xmu, |
|
scalar_t* __restrict__ grad_input, |
|
const accscalar_t norm_fct, |
|
const int reduction_size, |
|
const int stride) { |
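  // Channels-last version of the elementwise backward formula in
  // batch_norm_backward_elemt_kernel_impl above.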
|
|
|
|
|
int inner_loop_stride = blockDim.y * gridDim.y; |
|
|
|
|
|
int m_offset = blockIdx.y * blockDim.y + threadIdx.y; |
|
int c_offset = blockIdx.x * blockDim.x + threadIdx.x; |
|
|
|
if (c_offset >= stride || m_offset >= reduction_size) { |
|
return; |
|
} |
|
|
|
auto m_c = mean[c_offset]; |
|
auto m_dy_c = sum_dy[c_offset] * norm_fct; |
|
auto factor_1_c = inv_std[c_offset]; |
|
auto factor_2_c = (weight == nullptr? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c; |
|
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct; |
|
|
|
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); |
|
int address_base = m_offset * stride + c_offset; |
|
int address_increment = inner_loop_stride * stride; |
|
|
|
for (int i = 0; i < loop_count; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < PARALLEL_LOADS; j++) { |
|
if (c_offset < stride && m_offset < reduction_size) { |
|
grad_input[address_base] = static_cast<scalar_t>( |
|
(static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c - |
|
(static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c) |
|
* factor_2_c); |
|
} |
|
m_offset += inner_loop_stride; |
|
address_base += address_increment; |
|
} |
|
} |
|
} |
|
|
|
template < |
|
int PARALLEL_LOADS, |
|
typename scalar_t, |
|
typename accscalar_t, |
|
typename layerscalar_t> |
|
__global__ void batch_norm_backward_elemt_channels_last_kernel( |
|
const scalar_t* __restrict__ grad_output, |
|
const scalar_t* __restrict__ input, |
|
const accscalar_t* __restrict__ mean, |
|
const accscalar_t* __restrict__ inv_std, |
|
const layerscalar_t* __restrict__ weight, |
|
const accscalar_t* __restrict__ sum_dy, |
|
const accscalar_t* __restrict__ sum_dy_xmu, |
|
const int* __restrict__ numel, |
|
scalar_t* __restrict__ grad_input, |
|
const int64_t world_size, |
|
const int reduction_size, |
|
const int stride) { |
|
|
|
int64_t total_numel = 0; |
|
for (int i = 0; i < world_size; i++) { |
|
total_numel += numel[i]; |
|
} |
|
|
|
auto norm_fct = static_cast<accscalar_t>(1) / static_cast<accscalar_t>(total_numel); |
|
batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>( |
|
grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, |
|
grad_input, norm_fct, reduction_size, stride); |
|
} |
|
|
|
template < |
|
int PARALLEL_LOADS, |
|
typename scalar_t, |
|
typename accscalar_t, |
|
typename layerscalar_t> |
|
__global__ void batch_norm_backward_elemt_channels_last_kernel( |
|
const scalar_t* __restrict__ grad_output, |
|
const scalar_t* __restrict__ input, |
|
const accscalar_t* __restrict__ mean, |
|
const accscalar_t* __restrict__ inv_std, |
|
const layerscalar_t* __restrict__ weight, |
|
const accscalar_t* __restrict__ sum_dy, |
|
const accscalar_t* __restrict__ sum_dy_xmu, |
|
scalar_t* __restrict__ grad_input, |
|
const accscalar_t norm_fct, |
|
const int reduction_size, |
|
const int stride) { |
|
batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>( |
|
grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, |
|
grad_input, norm_fct, reduction_size, stride); |
|
} |
|
|
|
template<typename scalar_t, typename VarTransform> |
|
void batch_norm_stats_channels_last_cuda_template( |
|
const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) { |
|
using accscalar_t = at::acc_type<scalar_t, true>; |
|
|
|
const auto stride = input.sizes()[1]; |
|
const auto reduction_size = input.numel() / stride; |
|
|
|
resize_output(out_mean, {stride}); |
|
resize_output(out_invstd, {stride}); |
|
TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && |
|
out_invstd.sizes()[0]); |
|
TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && |
|
out_mean.sizes()[0]); |
|
|
|
dim3 block; |
|
dim3 grid; |
|
flexible_launch_configs(reduction_size, stride, block, grid, true); |
|
|
|
at::Tensor staging_data; |
|
at::Tensor semaphores; |
|
if (grid.y > 1) { |
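    // Multiple blocks cooperate on each channel: staging_data buffers the
    // per-block partial mean / m2n / count values and semaphores tracks how
    // many blocks of each column have finished, so the last one can do the
    // final merge.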
|
staging_data = at::empty({4*stride*grid.y}, out_mean.options()); |
|
semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); |
|
} |
|
|
|
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr; |
|
int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr; |
|
batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER> |
|
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>( |
|
input.data_ptr<scalar_t>(), |
|
out_mean.data_ptr<accscalar_t>(), |
|
out_invstd.data_ptr<accscalar_t>(), |
|
staging_data_ptr, |
|
semaphores_ptr, |
|
reduction_size, |
|
stride, |
|
epsilon); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
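// Minimal usage sketch for the statistics template above (illustrative only:
// `x` stands for a hypothetical channels-last float tensor, and `InvStd<float>`
// for whichever variance transform functor this header defines for the invstd output):
//
//   // auto mean   = at::empty({x.size(1)}, x.options().dtype(at::kFloat));
//   // auto invstd = at::empty({x.size(1)}, x.options().dtype(at::kFloat));
//   // batch_norm_stats_channels_last_cuda_template<float, InvStd<float>>(
//   //     mean, invstd, x, /*epsilon=*/1e-5);
//
// staging_data and semaphores are only allocated when grid.y > 1, i.e. when the
// reduction over N*H*W is split across blocks and partial results have to be
// combined through global memory.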
|
|
|
void batch_norm_elemt_channels_last_cuda_template( |
|
const at::Tensor& output, |
|
const at::Tensor& input, |
|
const at::Tensor& weight, |
|
const at::Tensor& shift, |
|
const at::Tensor& mean, |
|
const at::Tensor& inv_std, |
|
const at::optional<at::Tensor>& z = c10::nullopt,  // optional tensor added to the normalized output before the fused ReLU
|
const bool fuse_relu = false) { |
|
const auto stride = input.sizes()[1]; |
|
const auto reduction_size = input.numel() / stride; |
|
|
|
dim3 block; |
|
dim3 grid; |
|
flexible_launch_configs(reduction_size, stride, block, grid); |
|
|
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
const auto second_dtype = weight.defined() ? weight.scalar_type() : |
|
(shift.defined() ? shift.scalar_type() : input.scalar_type()); |
|
|
|
if (input.scalar_type() != second_dtype) { |
|
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { |
|
using accscalar_t = at::acc_type<scalar_t, true>; |
|
batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER> |
|
<<<grid, block, 0, stream>>>( |
|
input.data_ptr<scalar_t>(), |
|
z.has_value() ? z.value().data_ptr<scalar_t>() : nullptr, |
|
mean.data_ptr<accscalar_t>(), |
|
inv_std.data_ptr<accscalar_t>(), |
|
weight.defined() ? weight.data_ptr<accscalar_t>() : nullptr, |
|
shift.defined() ? shift.data_ptr<accscalar_t>() : nullptr, |
|
output.data_ptr<scalar_t>(), |
|
reduction_size, |
|
stride, |
|
fuse_relu); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
}); |
|
} else { |
|
if (weight.defined()){ |
|
TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(), |
|
" is not supported with weight.scalar_type() ", weight.scalar_type()); |
|
} |
|
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { |
|
using accscalar_t = at::acc_type<scalar_t, true>; |
|
batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER> |
|
<<<grid, block, 0, stream>>>( |
|
input.data_ptr<scalar_t>(), |
|
z.has_value() ? z.value().data_ptr<scalar_t>() : nullptr, |
|
mean.data_ptr<accscalar_t>(), |
|
inv_std.data_ptr<accscalar_t>(), |
|
weight.defined() ? weight.data_ptr<scalar_t>() : nullptr, |
|
shift.defined() ? shift.data_ptr<scalar_t>(): nullptr, |
|
output.data_ptr<scalar_t>(), |
|
reduction_size, |
|
stride, |
|
fuse_relu); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
}); |
|
} |
|
} |
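// Per-element summary of the forward transform launched above (weight/shift may be
// undefined, in which case nullptr is passed to the kernel):
//
//   y = (x - mean[c]) * inv_std[c] * weight[c] + shift[c]   // affine batch norm
//   y = y + z                                               // if z is provided
//   y = max(y, 0)                                           // if fuse_relu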
|
|
|
std::tuple<Tensor, Tensor, Tensor, Tensor> |
|
batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output, |
|
const at::Tensor& input, |
|
const at::Tensor& mean, |
|
const at::Tensor& inv_std, |
|
const at::Tensor& weight, |
|
const bool input_g, const bool weight_g, const bool bias_g) { |
|
const auto stride = input.sizes()[1]; |
|
const auto reduction_size = input.numel() / stride; |
|
|
|
at::Tensor sum_dy = at::empty({stride}, mean.options());
|
at::Tensor sum_dy_xmu = at::empty({stride}, mean.options()); |
|
|
|
at::Tensor grad_weight; |
|
at::Tensor grad_bias; |
|
if (weight.defined()) { |
|
grad_weight = at::empty({stride}, weight.options()); |
|
grad_bias = at::empty({stride}, weight.options()); |
|
} else { |
|
|
|
// empty placeholders so the function always returns initialized tensors
grad_weight = at::empty({0}, mean.options());
|
grad_bias = at::empty({0}, mean.options()); |
|
} |
|
|
|
dim3 block; |
|
dim3 grid; |
|
flexible_launch_configs(reduction_size, stride, block, grid, true); |
|
|
|
at::Tensor staging_data; |
|
at::Tensor semaphores; |
|
if (grid.y > 1) { |
|
staging_data = at::empty({2*stride*grid.y}, mean.options()); |
|
semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); |
|
} |
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
|
|
if (weight.defined() && input.scalar_type() != weight.scalar_type()) { |
|
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { |
|
using accscalar_t = at::acc_type<scalar_t, true>; |
|
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr; |
|
int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr; |
|
batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER> |
|
<<<grid, block, 0, stream>>>( |
|
input.data_ptr<scalar_t>(), |
|
grad_output.data_ptr<scalar_t>(), |
|
mean.data_ptr<accscalar_t>(), |
|
inv_std.data_ptr<accscalar_t>(), |
|
sum_dy.data_ptr<accscalar_t>(),
|
sum_dy_xmu.data_ptr<accscalar_t>(), |
|
grad_weight.data_ptr<accscalar_t>(), |
|
grad_bias.data_ptr<accscalar_t>(), |
|
staging_data_ptr, |
|
semaphores_ptr, |
|
reduction_size, |
|
stride); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
}); |
|
} else { |
|
if (weight.defined()) { |
|
TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(), |
|
" is not supported with weight.scalar_type() ", weight.scalar_type()); |
|
} |
|
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { |
|
using accscalar_t = at::acc_type<scalar_t, true>; |
|
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr; |
|
int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr; |
|
batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER> |
|
<<<grid, block, 0, stream>>>( |
|
input.data_ptr<scalar_t>(), |
|
grad_output.data_ptr<scalar_t>(), |
|
mean.data_ptr<accscalar_t>(), |
|
inv_std.data_ptr<accscalar_t>(), |
|
sum_dy.data_ptr<accscalar_t>(),
|
sum_dy_xmu.data_ptr<accscalar_t>(), |
|
weight.defined() ? grad_weight.data_ptr<scalar_t>() : nullptr, |
|
weight.defined() ? grad_bias.data_ptr<scalar_t>() : nullptr, |
|
staging_data_ptr, |
|
semaphores_ptr, |
|
reduction_size, |
|
stride); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
}); |
|
} |
|
|
|
return std::make_tuple(sum_dy, sum_dy_xmu, grad_weight, grad_bias);
|
} |
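// Per channel c, the reduction above produces the standard batch-norm backward sums:
//
//   sum_dy[c]      = sum_n dy[n][c]
//   sum_dy_xmu[c]  = sum_n dy[n][c] * (x[n][c] - mean[c])
//   grad_bias[c]   = sum_dy[c]
//   grad_weight[c] = sum_dy_xmu[c] * inv_std[c]
//
// grad_weight / grad_bias are only meaningful when `weight` is defined; otherwise
// empty placeholders are returned.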
|
|
|
at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( |
|
const at::Tensor& grad_output, |
|
const at::Tensor& input, |
|
const at::Tensor& mean, |
|
const at::Tensor& inv_std, |
|
const at::Tensor& weight, |
|
const at::Tensor& sum_dy, |
|
const at::Tensor& sum_dy_xmu, |
|
const at::Tensor& count) { |
|
const auto stride = input.sizes()[1]; |
|
const auto reduction_size = input.numel() / stride; |
|
|
|
|
|
at::Tensor grad_input = at::empty_like(input); |
|
|
|
dim3 block; |
|
dim3 grid; |
|
flexible_launch_configs(reduction_size, stride, block, grid); |
|
|
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
|
|
if (weight.defined() && weight.scalar_type() != input.scalar_type()) { |
|
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { |
|
using accscalar_t = at::acc_type<scalar_t, true>; |
|
batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER> |
|
<<<grid, block, 0, stream>>>( |
|
grad_output.data_ptr<scalar_t>(), |
|
input.data_ptr<scalar_t>(), |
|
mean.data_ptr<accscalar_t>(), |
|
inv_std.data_ptr<accscalar_t>(), |
|
weight.data_ptr<accscalar_t>(), |
|
sum_dy.data_ptr<accscalar_t>(), |
|
sum_dy_xmu.data_ptr<accscalar_t>(), |
|
count.data_ptr<int>(), |
|
grad_input.data_ptr<scalar_t>(), |
|
count.numel(), |
|
reduction_size, |
|
stride); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
}); |
|
} else { |
|
if (weight.defined()) { |
|
TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(), |
|
" is not supported with weight.scalar_type() ", weight.scalar_type()); |
|
} |
|
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
|
using accscalar_t = at::acc_type<scalar_t, true>; |
|
batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER> |
|
<<<grid, block, 0, stream>>>( |
|
grad_output.data_ptr<scalar_t>(), |
|
input.data_ptr<scalar_t>(), |
|
mean.data_ptr<accscalar_t>(), |
|
inv_std.data_ptr<accscalar_t>(), |
|
weight.defined() ? weight.data_ptr<scalar_t>() : nullptr, |
|
sum_dy.data_ptr<accscalar_t>(), |
|
sum_dy_xmu.data_ptr<accscalar_t>(), |
|
count.data_ptr<int>(), |
|
grad_input.data_ptr<scalar_t>(), |
|
count.numel(), |
|
reduction_size, |
|
stride); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
}); |
|
} |
|
|
|
return grad_input; |
|
} |
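// Illustrative call into the count-based variant above (hypothetical tensors; this
// path is typically driven by the synchronized batch-norm backward):
//
//   // auto count = at::tensor({n0, n1}, input.options().dtype(at::kInt));
//   // auto gi = batch_norm_backward_elemt_channels_last_cuda_template(
//   //     grad_out, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count);
//
// count.numel() is forwarded as world_size and the reductions are normalized by the
// sum of all per-rank counts.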
|
|
|
at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( |
|
const at::Tensor& grad_output, |
|
const at::Tensor& input, |
|
const at::Tensor& mean, |
|
const at::Tensor& inv_std, |
|
const at::Tensor& weight, |
|
const at::Tensor& sum_dy, |
|
const at::Tensor& sum_dy_xmu) { |
|
const auto stride = input.sizes()[1]; |
|
const auto reduction_size = input.numel() / stride; |
|
auto norm_fct = 1.0 / reduction_size;  // statistics were reduced over this tensor only, so normalize by its per-channel element count (N*H*W)
|
|
|
|
|
at::Tensor grad_input = at::empty_like(input); |
|
|
|
dim3 block; |
|
dim3 grid; |
|
flexible_launch_configs(reduction_size, stride, block, grid); |
|
|
|
auto stream = at::cuda::getCurrentCUDAStream(); |
|
|
|
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { |
|
using accscalar_t = at::acc_type<scalar_t, true>; |
|
|
|
if (weight.defined() && weight.scalar_type() != input.scalar_type()) { |
|
batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER> |
|
<<<grid, block, 0, stream>>>( |
|
grad_output.data_ptr<scalar_t>(), |
|
input.data_ptr<scalar_t>(), |
|
mean.data_ptr<accscalar_t>(), |
|
inv_std.data_ptr<accscalar_t>(), |
|
weight.data_ptr<accscalar_t>(), |
|
sum_dy.data_ptr<accscalar_t>(), |
|
sum_dy_xmu.data_ptr<accscalar_t>(), |
|
grad_input.data_ptr<scalar_t>(), |
|
static_cast<accscalar_t>(norm_fct), |
|
reduction_size, |
|
stride); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} else { |
|
batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER> |
|
<<<grid, block, 0, stream>>>( |
|
grad_output.data_ptr<scalar_t>(), |
|
input.data_ptr<scalar_t>(), |
|
mean.data_ptr<accscalar_t>(), |
|
inv_std.data_ptr<accscalar_t>(), |
|
weight.defined() ? weight.data_ptr<scalar_t>() : nullptr, |
|
sum_dy.data_ptr<accscalar_t>(), |
|
sum_dy_xmu.data_ptr<accscalar_t>(), |
|
grad_input.data_ptr<scalar_t>(), |
|
static_cast<accscalar_t>(norm_fct), |
|
reduction_size, |
|
stride); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
|
}); |
|
|
|
return grad_input; |
|
} |
|
|
|
} } // namespace at::native
|
|