// NOTE(review): this file appears to have been captured from a web code listing.
// Page metadata preserved from the capture: status "Build error",
// file size 9,828 bytes, revision 3d49622.
// The listing's line-number gutter (1-334) carried no content and was dropped.
#include <ATen/ATen.h>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <vector>
#include "utils/checks.h"
#include "utils/cuda.cuh"
#include "inplace_abn.h"
#include <ATen/cuda/CUDAContext.h>
// Operations for reduce
// Element-fetch functor handed to reduce(): returns the raw value stored at
// (batch, plane, n) of a tensor addressed as (batch * chn + plane) * sp + n.
template<typename T>
struct SumOp {
  // data: base pointer of the tensor; channels / spatial: the two inner extents
  // used to flatten the (batch, plane, n) coordinate.
  __device__ SumOp(const T *data, int channels, int spatial)
      : tensor(data), chn(channels), sp(spatial) {}

  // Reads a single element; no arithmetic is applied to it.
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    const int offset = (batch * chn + plane) * sp + n;
    return tensor[offset];
  }

  const T *tensor;  // base pointer of the values being read
  const int chn;    // extent of the plane (channel) axis
  const int sp;     // extent of the innermost axis
};
// Element functor handed to reduce(): returns the squared deviation from a
// fixed mean of the value stored at (batch, plane, n), with the tensor
// addressed as (batch * chn + plane) * sp + n.
template<typename T>
struct VarOp {
  // center: the mean to subtract; data: base pointer of the tensor;
  // channels / spatial: the two inner extents used to flatten the coordinate.
  __device__ VarOp(T center, const T *data, int channels, int spatial)
      : mean(center), tensor(data), chn(channels), sp(spatial) {}

  // Returns (value - mean)^2 for one element.
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    const int offset = (batch * chn + plane) * sp + n;
    T val = tensor[offset];
    return (val - mean) * (val - mean);
  }

  const T mean;     // value subtracted from every element before squaring
  const T *tensor;  // base pointer of the values being read
  const int chn;    // extent of the plane (channel) axis
  const int sp;     // extent of the innermost axis
};
// Element functor handed to reduce() in the backward pass. For the element at
// (batch, plane, n) it yields the pair (dz, y * dz), where
// y = (z - bias) / weight undoes the affine step applied in forward_kernel.
template<typename T>
struct GradOp {
  // scale / shift: per-plane affine parameters; out / grad_out: base pointers of
  // z and dz; channels / spatial: the two inner extents used to flatten the
  // (batch, plane, n) coordinate.
  __device__ GradOp(T scale, T shift, const T *out, const T *grad_out, int channels, int spatial)
      : weight(scale), bias(shift), z(out), dz(grad_out), chn(channels), sp(spatial) {}

  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
    const int offset = (batch * chn + plane) * sp + n;
    T _y = (z[offset] - bias) / weight;
    T _dz = dz[offset];
    return Pair<T>(_dz, _y * _dz);
  }

  const T weight;  // divisor used to recover y from z
  const T bias;    // offset removed from z before dividing
  const T *z;      // base pointer of the forward output
  const T *dz;     // base pointer of the incoming gradient
  const int chn;   // extent of the plane (channel) axis
  const int sp;    // extent of the innermost axis
};
/***********
* mean_var
***********/
// Computes one mean and one variance value per plane.
//
// Each block handles the plane given by blockIdx.x (the host wrapper launches
// `chn` blocks). Both statistics are scaled by 1 / (num * sp).
// NOTE(review): reduce() comes from utils/cuda.cuh and is presumably a
// block-wide sum of op(batch, plane, n) over all batches and positions - confirm.
template<typename T>
__global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
  const int plane = blockIdx.x;
  const T inv_count = T(1) / T(num * sp);

  // First pass: accumulate the raw values.
  const T plane_mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * inv_count;

  // Barrier kept between the two reductions, exactly as in the original order.
  __syncthreads();

  // Second pass: accumulate squared deviations from the mean just computed.
  const T plane_var = reduce<T, VarOp<T>>(VarOp<T>(plane_mean, x, chn, sp), plane, num, sp) * inv_count;

  // Only thread 0 of the block publishes the per-plane results.
  if (threadIdx.x != 0) {
    return;
  }
  mean[plane] = plane_mean;
  var[plane] = plane_var;
}
// Host entry point: computes per-channel mean and variance of `x` on the
// current CUDA stream.
//
// Returns {mean, var}, two tensors of shape {chn} created with x's options.
// Throws (via AT_CUDA_CHECK) if the kernel launch is rejected by the runtime.
// NOTE(review): get_dims() is defined elsewhere; it is assumed to fill
// num / chn / sp with batch, channel and flattened-spatial extents - confirm.
std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
  CHECK_CUDA_INPUT(x);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Prepare output tensors
  auto mean = at::empty({chn}, x.options());
  auto var = at::empty({chn}, x.options());

  // A grid with zero blocks is an invalid launch configuration; with no
  // channels there is nothing to compute, so return the empty outputs.
  if (chn == 0) {
    return {mean, var};
  }

  // Run kernel: one block per channel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
    mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        x.data<scalar_t>(),
        mean.data<scalar_t>(),
        var.data<scalar_t>(),
        num, chn, sp);
  }));
  // Kernel launches do not report errors themselves; surface a bad launch here
  // instead of letting it appear at some unrelated later CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  return {mean, var};
}
/**********
* forward
**********/
// In-place normalization of `x`: every element of the plane given by
// blockIdx.x becomes (x - mean) * rsqrt(var + eps) * weight' + bias', where
// weight' = |weight| + eps and bias' = bias when `affine` is set, and
// weight' = 1, bias' = 0 otherwise.
//
// The threads of a block walk the innermost axis starting at threadIdx.x with
// a step of blockDim.x, repeating for each batch entry.
template<typename T>
__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
                               bool affine, float eps, int num, int chn, int sp) {
  const int plane = blockIdx.x;
  const T mu = mean[plane];
  const T sigma2 = var[plane];

  // weight / bias are only dereferenced when the affine transform is enabled.
  T gamma = T(1);
  T beta = T(0);
  if (affine) {
    gamma = abs(weight[plane]) + eps;
    beta = bias[plane];
  }

  // Combined per-plane multiplier.
  const T scale = rsqrt(sigma2 + eps) * gamma;

  for (int batch = 0; batch < num; ++batch) {
    const int base = (batch * chn + plane) * sp;
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      const T in = x[base + n];
      const T out = (in - mu) * scale + beta;
      x[base + n] = out;
    }
  }
}
// Host entry point: normalizes `x` in place with the given per-channel
// statistics (and, when `affine` is set, weight / bias) on the current CUDA
// stream, then returns `x`.
//
// Throws (via AT_CUDA_CHECK) if the kernel launch is rejected by the runtime.
// NOTE(review): get_dims() is defined elsewhere; it is assumed to fill
// num / chn / sp with batch, channel and flattened-spatial extents - confirm.
at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
                        bool affine, float eps) {
  CHECK_CUDA_INPUT(x);
  CHECK_CUDA_INPUT(mean);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // A grid with zero blocks is an invalid launch configuration; with no
  // channels there is nothing to normalize.
  if (chn == 0) {
    return x;
  }

  // Run kernel: one block per channel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
    forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        x.data<scalar_t>(),
        mean.data<scalar_t>(),
        var.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));
  // Kernel launches do not report errors themselves; surface a bad launch here
  // instead of letting it appear at some unrelated later CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  return x;
}
/***********
* edz_eydz
***********/
// Reduces, for the plane given by blockIdx.x, the pair produced by GradOp
// (dz and y * dz per element) and stores the two results in edz[plane] and
// eydz[plane].
//
// When `affine` is set the effective weight is |weight| + eps and the bias is
// read from `bias`; otherwise weight = 1 and bias = 0.
// NOTE(review): reduce() comes from utils/cuda.cuh and is presumably a
// block-wide sum over all batches and positions - confirm.
template<typename T>
__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
                                T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
  const int plane = blockIdx.x;

  // weight / bias are only dereferenced when the affine transform is enabled.
  T gamma = T(1);
  T beta = T(0);
  if (affine) {
    gamma = abs(weight[plane]) + eps;
    beta = bias[plane];
  }

  Pair<T> sums = reduce<Pair<T>, GradOp<T>>(GradOp<T>(gamma, beta, z, dz, chn, sp), plane, num, sp);
  __syncthreads();

  // Only thread 0 of the block publishes the per-plane results.
  if (threadIdx.x != 0) {
    return;
  }
  edz[plane] = sums.v1;
  eydz[plane] = sums.v2;
}
// Host entry point: computes the two per-channel reductions needed by the
// backward pass (see edz_eydz_kernel) on the current CUDA stream.
//
// Returns {edz, eydz}, two tensors of shape {chn} created with z's options.
// Throws (via AT_CUDA_CHECK) if the kernel launch is rejected by the runtime.
// NOTE(review): get_dims() is defined elsewhere; it is assumed to fill
// num / chn / sp with batch, channel and flattened-spatial extents - confirm.
std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
                                      bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto edz = at::empty({chn}, z.options());
  auto eydz = at::empty({chn}, z.options());

  // A grid with zero blocks is an invalid launch configuration; with no
  // channels there is nothing to reduce, so return the empty outputs.
  if (chn == 0) {
    return {edz, eydz};
  }

  // Run kernel: one block per channel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
    edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        z.data<scalar_t>(),
        dz.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        edz.data<scalar_t>(),
        eydz.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));
  // Kernel launches do not report errors themselves; surface a bad launch here
  // instead of letting it appear at some unrelated later CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  return {edz, eydz};
}
/***********
* backward
***********/
// Writes the input gradient for the plane given by blockIdx.x:
//   dx = (dz - edz / count - y * eydz / count) * weight' * rsqrt(var + eps)
// with y = (z - bias') / weight' and count = num * sp. When `affine` is set,
// weight' = |weight| + eps and bias' = bias; otherwise weight' = 1, bias' = 0.
//
// The threads of a block walk the innermost axis starting at threadIdx.x with
// a step of blockDim.x, repeating for each batch entry.
template<typename T>
__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
                                const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
  const int plane = blockIdx.x;

  // weight / bias are only dereferenced when the affine transform is enabled.
  T gamma = T(1);
  T beta = T(0);
  if (affine) {
    gamma = abs(weight[plane]) + eps;
    beta = bias[plane];
  }

  const T plane_var = var[plane];
  const T sum_dz = edz[plane];
  const T sum_ydz = eydz[plane];
  const T scale = gamma * rsqrt(plane_var + eps);
  const T count = T(num * sp);

  for (int batch = 0; batch < num; ++batch) {
    const int base = (batch * chn + plane) * sp;
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      const int idx = base + n;
      const T grad = dz[idx];
      // Recover the pre-affine value from the stored output.
      const T y = (z[idx] - beta) / gamma;
      dx[idx] = (grad - sum_dz / count - y * sum_ydz / count) * scale;
    }
  }
}
// Host entry point: computes the gradient with respect to the layer input on
// the current CUDA stream (see backward_kernel for the formula).
//
// Returns dx, a tensor created with at::zeros_like(z).
// Throws (via AT_CUDA_CHECK) if the kernel launch is rejected by the runtime.
// NOTE(review): get_dims() is defined elsewhere; it is assumed to fill
// num / chn / sp with batch, channel and flattened-spatial extents - confirm.
at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
                         at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);
  CHECK_CUDA_INPUT(edz);
  CHECK_CUDA_INPUT(eydz);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto dx = at::zeros_like(z);

  // A grid with zero blocks is an invalid launch configuration; with no
  // channels there is nothing to compute.
  if (chn == 0) {
    return dx;
  }

  // Run kernel: one block per channel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
    backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        z.data<scalar_t>(),
        dz.data<scalar_t>(),
        var.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        edz.data<scalar_t>(),
        eydz.data<scalar_t>(),
        dx.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));
  // Kernel launches do not report errors themselves; surface a bad launch here
  // instead of letting it appear at some unrelated later CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  return dx;
}
/**************
* activations
**************/
// Leaky-ReLU backward step over `count` elements, run through Thrust on the
// current CUDA stream. Both buffers are modified in place:
//   1. dz[i] <- dz[i] * slope   wherever z[i] < 0
//   2. z[i]  <- z[i] / slope    wherever z[i] < 0
// Step 1 uses z as the stencil, so it must run before step 2 rewrites z.
template<typename T>
inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
  // Wrap the raw pointers for Thrust
  thrust::device_ptr<T> out_first = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> grad_first = thrust::device_pointer_cast(dz);
  thrust::device_ptr<T> out_last = out_first + count;
  thrust::device_ptr<T> grad_last = grad_first + count;

  auto stream = at::cuda::getCurrentCUDAStream();
  auto policy = thrust::cuda::par.on(stream);

  auto is_negative = [] __device__ (const T& z) { return z < 0; };
  auto scale_grad = [slope] __device__ (const T& dz) { return dz * slope; };
  auto unscale_out = [slope] __device__ (const T& z) { return z / slope; };

  // Scale the gradient where the output is negative.
  thrust::transform_if(policy, grad_first, grad_last, out_first, grad_first, scale_grad, is_negative);

  // Divide the negative outputs by slope, in place.
  thrust::transform_if(policy, out_first, out_last, out_first, unscale_out, is_negative);
}
// Host entry point for the leaky-ReLU backward step: validates both tensors,
// then dispatches on z's floating-point type to leaky_relu_backward_impl,
// which updates z and dz in place over z.numel() elements.
void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  const int64_t count = z.numel();

  AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
    scalar_t *z_ptr = z.data<scalar_t>();
    scalar_t *dz_ptr = dz.data<scalar_t>();
    leaky_relu_backward_impl<scalar_t>(z_ptr, dz_ptr, slope, count);
  }));
}
// ELU backward step over `count` elements, run through Thrust on the current
// CUDA stream. Both buffers are modified in place:
//   1. dz[i] <- dz[i] * (z[i] + 1)   wherever z[i] < 0
//   2. z[i]  <- log1p(z[i])          wherever z[i] < 0
// Step 1 reads z both as an operand and as the stencil, so it must run before
// step 2 rewrites z.
template<typename T>
inline void elu_backward_impl(T *z, T *dz, int64_t count) {
  // Create thrust pointers
  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
  auto stream = at::cuda::getCurrentCUDAStream();

  // T(1) instead of the double literal `1.` keeps the arithmetic in T; the
  // literal promoted the whole expression to double in the float instantiation.
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_dz, th_dz + count, th_z, th_z, th_dz,
                       [] __device__ (const T& dz, const T& z) { return dz * (z + T(1)); },
                       [] __device__ (const T& z) { return z < 0; });

  // Rewrite the negative outputs in place.
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_z, th_z + count, th_z,
                       [] __device__ (const T& z) { return log1p(z); },
                       [] __device__ (const T& z) { return z < 0; });
}
// Host entry point for the ELU backward step: validates both tensors, then
// dispatches on z's floating-point type to elu_backward_impl, which updates
// z and dz in place over z.numel() elements.
void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();

  // The dispatch name is what appears in the unsupported-dtype error message;
  // it previously read "leaky_relu_backward_cuda" (copy-paste), which pointed
  // at the wrong function.
  AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] {
    elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
  }));
}
// (end of file)