// NOTE(review): this file arrived with every `<...>` span stripped (include
// targets, template parameter lists, `<<<...>>>` launch chevrons, and the
// kernel body from `if (index <` onward were lost). The code below restores
// the surviving fragments to compilable form; spans marked NOTE(review) are
// reconstructions and must be verified against the upstream source.

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Clamp v into [lo, hi]. bound_t is separate from scalar_t so float literal
// bounds can be used with any floating scalar type.
template <typename scalar_t, typename bound_t>
__device__ __forceinline__ scalar_t clamp(const scalar_t v, const bound_t lo, const bound_t hi) {
  return min(max(v, lo), hi);
}

// Adds the gradient of a clamped (Huber-like) total-variation penalty over a
// dense 3D grid into `grad`, one thread per element.
//
// Layout: the last three dimensions of the flattened tensor are
// (sz_i, sz_j, sz_k), with k fastest-varying; N is the total element count.
// wx/wy/wz weight the i/j/k-axis differences respectively.
//
// Launch: 1D grid of 1D blocks covering at least N threads; each thread
// bounds-checks `index < N`.
//
// NOTE(review): the body below (index decomposition and the six neighbor
// difference terms) was reconstructed — the original body was lost to the
// bracket-stripping corruption. Verify signs/offsets against upstream.
template <typename scalar_t>
__global__ void total_variation_add_grad_cuda_kernel(
    const scalar_t* __restrict__ param,
    scalar_t* __restrict__ grad,
    float wx, float wy, float wz,
    const size_t sz_i, const size_t sz_j, const size_t sz_k, const size_t N) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < N) {
    // Decompose the flat index into (i, j, k) grid coordinates.
    const size_t k = index % sz_k;
    const size_t j = index / sz_k % sz_j;
    const size_t i = index / (sz_k * sz_j) % sz_i;

    // d/dp |p - q| with the difference clamped to [-1, 1] is
    // clamp(p - q, -1, 1); accumulate one term per in-bounds neighbor.
    float grad_to_add = 0.0f;
    grad_to_add += (k == 0        ? 0.0f : wz * clamp(param[index] - param[index - 1], -1.f, 1.f));
    grad_to_add += (k == sz_k - 1 ? 0.0f : wz * clamp(param[index] - param[index + 1], -1.f, 1.f));
    grad_to_add += (j == 0        ? 0.0f : wy * clamp(param[index] - param[index - sz_k], -1.f, 1.f));
    grad_to_add += (j == sz_j - 1 ? 0.0f : wy * clamp(param[index] - param[index + sz_k], -1.f, 1.f));
    grad_to_add += (i == 0        ? 0.0f : wx * clamp(param[index] - param[index - sz_k * sz_j], -1.f, 1.f));
    grad_to_add += (i == sz_i - 1 ? 0.0f : wx * clamp(param[index] - param[index + sz_k * sz_j], -1.f, 1.f));
    grad[index] += grad_to_add;
  }
}

// Host entry point: dispatches the TV-gradient kernel over `param`'s
// floating dtype and accumulates into `grad` (same shape/dtype as `param`).
//
// NOTE(review): the wrapper's head (signature, size extraction, launch
// configuration) was lost to the corruption and is reconstructed here; the
// dimension indices (2, 3, 4) assume a 5D (1, C, I, J, K)-style grid tensor —
// confirm against callers. Both visible dispatch branches launched the same
// kernel, so `dense_mode` currently selects between two identical launches;
// the flag is kept for interface compatibility.
void total_variation_add_grad_cuda(
    torch::Tensor param, torch::Tensor grad,
    float wx, float wy, float wz, bool dense_mode) {
  const size_t N = param.numel();
  const size_t sz_i = param.size(2);
  const size_t sz_j = param.size(3);
  const size_t sz_k = param.size(4);
  if (N == 0) {
    return;  // nothing to do; avoids a zero-block launch
  }
  const int threads = 256;
  const int blocks = (N + threads - 1) / threads;  // ceil-div over N elements

  if (dense_mode) {
    // .scalar_type() / .data_ptr<T>() replace the deprecated
    // .type() / .data<T>() forms that appeared in the corrupted original.
    AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "total_variation_add_grad_cuda", ([&] {
      total_variation_add_grad_cuda_kernel<scalar_t><<<blocks, threads>>>(
          param.data_ptr<scalar_t>(),
          grad.data_ptr<scalar_t>(),
          wx, wy, wz,
          sz_i, sz_j, sz_k, N);
    }));
  } else {
    AT_DISPATCH_FLOATING_TYPES(param.scalar_type(), "total_variation_add_grad_cuda", ([&] {
      total_variation_add_grad_cuda_kernel<scalar_t><<<blocks, threads>>>(
          param.data_ptr<scalar_t>(),
          grad.data_ptr<scalar_t>(),
          wx, wy, wz,
          sz_i, sz_j, sz_k, N);
    }));
  }
}