// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#define EIGEN_USE_GPU
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include <stdio.h>

using namespace tensorflow;
using namespace tensorflow::shape_inference;

//------------------------------------------------------------------------
// Helpers.

#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)

static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
{
    int t = 1 - a / b;
    return (a + t * b) / b - t;
}

//------------------------------------------------------------------------
// CUDA kernel params.

template <class T>
struct UpFirDn2DKernelParams
{
    const T*    x;          // [majorDim, inH, inW, minorDim]
    const T*    k;          // [kernelH, kernelW]
    T*          y;          // [majorDim, outH, outW, minorDim]

    int         upx;
    int         upy;
    int         downx;
    int         downy;
    int         padx0;
    int         padx1;
    int         pady0;
    int         pady1;

    int         majorDim;
    int         inH;
    int         inW;
    int         minorDim;
    int         kernelH;
    int         kernelW;
    int         outH;
    int         outW;
    int         loopMajor;
    int         loopX;
};
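// For orientation (illustrative only, not used by the kernels below): the op
// implements the classic upfirdn pipeline -- upsample by (upx, upy) with zero
// insertion, pad by (padx0, padx1, pady0, pady1), convolve with the
// [kernelH, kernelW] FIR filter, then downsample by (downx, downy). The
// hypothetical helper below simply restates the output-size formula used in
// Compute() further down, with one worked example in the comment.

static __host__ __device__ __forceinline__ int upFirDnOutputSize(int inSize, int up, int down, int pad0, int pad1, int kernelSize)
{
    // e.g. inSize = 64, up = 2, down = 1, pad0 = 2, pad1 = 1, kernelSize = 4 -> 128.
    return (inSize * up + pad0 + pad1 - kernelSize + down) / down;
}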
//------------------------------------------------------------------------
// General CUDA implementation for large filter kernels.

template <class T>
static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
{
    // Calculate thread index.
    int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorIdx / p.minorDim;
    minorIdx -= outY * p.minorDim;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorIdxBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
        return;

    // Setup Y receptive field.
    int midY = outY * p.downy + p.upy - 1 - p.pady0;
    int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
    int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
    int kernelY = midY + p.kernelH - (inY + 1) * p.upy;

    // Loop over majorDim and outX.
    for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
    for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
    {
        // Setup X receptive field.
        int midX = outX * p.downx + p.upx - 1 - p.padx0;
        int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
        int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
        int kernelX = midX + p.kernelW - (inX + 1) * p.upx;

        // Initialize pointers.
        const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
        const T* kp = &p.k[kernelY * p.kernelW + kernelX];
        int xpx = p.minorDim;
        int kpx = -p.upx;
        int xpy = p.inW * p.minorDim;
        int kpy = -p.upy * p.kernelW;

        // Inner loop.
        float v = 0.0f;
        for (int y = 0; y < h; y++)
        {
            for (int x = 0; x < w; x++)
            {
                v += (float)(*xp) * (float)(*kp);
                xp += xpx;
                kp += kpx;
            }
            xp += xpy - w * xpx;
            kp += kpy - w * kpx;
        }

        // Store result.
        p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
    }
}

//------------------------------------------------------------------------
// Specialized CUDA implementation for small filter kernels.

template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
{
    //assert(kernelW % upx == 0);
    //assert(kernelH % upy == 0);
    const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
    __shared__ volatile float sk[kernelH][kernelW];
    __shared__ volatile float sx[tileInH][tileInW];

    // Calculate tile index.
    int minorIdx = blockIdx.x;
    int tileOutY = minorIdx / p.minorDim;
    minorIdx -= tileOutY * p.minorDim;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorIdxBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
        return;

    // Load filter kernel (flipped).
    for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
    {
        int ky = tapIdx / kernelW;
        int kx = tapIdx - ky * kernelW;
        float v = 0.0f;
        if (kx < p.kernelW & ky < p.kernelH)
            v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
        sk[ky][kx] = v;
    }

    // Loop over majorDim and outX.
    for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
    for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
    {
        // Load input pixels.
        int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
        int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
        int tileInX = floorDiv(tileMidX, upx);
        int tileInY = floorDiv(tileMidY, upy);
        __syncthreads();
        for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
        {
            int relInY = inIdx / tileInW;
            int relInX = inIdx - relInY * tileInW;
            int inX = relInX + tileInX;
            int inY = relInY + tileInY;
            float v = 0.0f;
            if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
                v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
            sx[relInY][relInX] = v;
        }

        // Loop over output pixels.
        __syncthreads();
        for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
        {
            int relOutY = outIdx / tileOutW;
            int relOutX = outIdx - relOutY * tileOutW;
            int outX = relOutX + tileOutX;
            int outY = relOutY + tileOutY;

            // Setup receptive field.
            int midX = tileMidX + relOutX * downx;
            int midY = tileMidY + relOutY * downy;
            int inX = floorDiv(midX, upx);
            int inY = floorDiv(midY, upy);
            int relInX = inX - tileInX;
            int relInY = inY - tileInY;
            int kernelX = (inX + 1) * upx - midX - 1; // flipped
            int kernelY = (inY + 1) * upy - midY - 1; // flipped

            // Inner loop.
            float v = 0.0f;
            #pragma unroll
            for (int y = 0; y < kernelH / upy; y++)
                #pragma unroll
                for (int x = 0; x < kernelW / upx; x++)
                    v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];

            // Store result.
            if (outX < p.outW & outY < p.outH)
                p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
        }
    }
}
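// Shared-memory footprint of the tiled kernel above, worked through for one
// specialization (illustrative numbers only): with upx = upy = downx = downy = 1,
// kernelW = kernelH = 7 and a 64x16 output tile, tileInW = (64 - 1) + 7 - 1 + 1 = 70
// and tileInH = (16 - 1) + 7 - 1 + 1 = 22, so sx holds 22 * 70 floats (~6 KB) and
// sk holds 7 * 7 floats; both fit well within the per-block shared memory budget.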
//------------------------------------------------------------------------
// TensorFlow op.

template <class T>
struct UpFirDn2DOp : public OpKernel
{
    UpFirDn2DKernelParams<T> m_attribs;

    UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
    {
        memset(&m_attribs, 0, sizeof(m_attribs));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
        OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
        OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
    }

    void Compute(OpKernelContext* ctx)
    {
        UpFirDn2DKernelParams<T> p = m_attribs;
        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();

        const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
        const Tensor& k = ctx->input(1); // [kernelH, kernelW]
        p.x = x.flat<T>().data();
        p.k = k.flat<T>().data();
        OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
        OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
        OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
        OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));

        p.majorDim  = (int)x.dim_size(0);
        p.inH       = (int)x.dim_size(1);
        p.inW       = (int)x.dim_size(2);
        p.minorDim  = (int)x.dim_size(3);
        p.kernelH   = (int)k.dim_size(0);
        p.kernelW   = (int)k.dim_size(1);
        OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));

        p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
        p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
        OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));

        Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
        TensorShape ys;
        ys.AddDim(p.majorDim);
        ys.AddDim(p.outH);
        ys.AddDim(p.outW);
        ys.AddDim(p.minorDim);
        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
        p.y = y->flat<T>().data();
        OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
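        // Shape bookkeeping, worked through for one concrete (illustrative) case:
        // x of shape [16, 64, 64, 3] with a 4x4 filter, upx = upy = 2, downx = downy = 1,
        // padx0 = pady0 = 2, padx1 = pady1 = 1 yields outW = outH = (64*2 + 2 + 1 - 4 + 1) / 1 = 128,
        // i.e. y is allocated with shape [16, 128, 128, 3].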
        // Choose CUDA kernel to use.
        void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
        int tileOutW = -1;
        int tileOutH = -1;
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7  && p.kernelH <= 7 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5  && p.kernelH <= 5 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3  && p.kernelH <= 3 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 8,1,  128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,24, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,20, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,16, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,12, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,8,  32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 8,1,  128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,24, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,20, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,16, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,12, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,8,  32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8,  32,8>;  tileOutW = 32;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6,  32,8>;  tileOutW = 32;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4,  32,8>;  tileOutW = 32;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2,  32,8>;  tileOutW = 32;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 24,1, 64,8>;  tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 20,1, 64,8>;  tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 16,1, 64,8>;  tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 12,1, 64,8>;  tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 8,1,  64,8>;  tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,24, 32,16>; tileOutW = 32;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,20, 32,16>; tileOutW = 32;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,16, 32,16>; tileOutW = 32;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,12, 32,16>; tileOutW = 32;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,8,  32,16>; tileOutW = 32;  tileOutH = 16; }

        // Choose launch params.
        dim3 blockSize;
        dim3 gridSize;
        if (tileOutW > 0 && tileOutH > 0) // small
        {
            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
            p.loopX = 1;
            blockSize = dim3(32 * 8, 1, 1);
            gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }
        else // large
        {
            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
            p.loopX = 4;
            blockSize = dim3(4, 32, 1);
            gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }

        // Launch CUDA kernel.
        void* args[] = {&p};
        OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
    }
};

REGISTER_OP("UpFirDn2D")
    .Input      ("x: T")
    .Input      ("k: T")
    .Output     ("y: T")
    .Attr       ("T: {float, half}")
    .Attr       ("upx: int = 1")
    .Attr       ("upy: int = 1")
    .Attr       ("downx: int = 1")
    .Attr       ("downy: int = 1")
    .Attr       ("padx0: int = 0")
    .Attr       ("padx1: int = 0")
    .Attr       ("pady0: int = 0")
    .Attr       ("pady1: int = 0");
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);

//------------------------------------------------------------------------
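// Illustrative reference implementation (not part of the op above). A plain
// CPU routine like this is handy for spot-checking the GPU kernels on small
// inputs; the name UpFirDn2DReference and the float-only signature are
// hypothetical, but the indexing mirrors the upsample-pad-filter-downsample
// semantics implemented by the kernels.

static void UpFirDn2DReference(const float* x, const float* k, float* y,
    int majorDim, int inH, int inW, int minorDim, int kernelH, int kernelW,
    int upx, int upy, int downx, int downy, int padx0, int pady0, int outH, int outW)
{
    for (int n = 0; n < majorDim; n++)
    for (int outY = 0; outY < outH; outY++)
    for (int outX = 0; outX < outW; outX++)
    for (int c = 0; c < minorDim; c++)
    {
        float v = 0.0f;
        for (int ky = 0; ky < kernelH; ky++)
        for (int kx = 0; kx < kernelW; kx++)
        {
            // Position of this filter tap in the zero-upsampled, zero-padded input.
            int ux = outX * downx + kernelW - 1 - kx - padx0;
            int uy = outY * downy + kernelH - 1 - ky - pady0;

            // Only positions that coincide with an actual input sample contribute.
            if (ux < 0 || uy < 0 || ux % upx != 0 || uy % upy != 0)
                continue;
            int inX = ux / upx;
            int inY = uy / upy;
            if (inX >= inW || inY >= inH)
                continue;

            v += x[((n * inH + inY) * inW + inX) * minorDim + c] * k[ky * kernelW + kx];
        }
        y[((n * outH + outY) * outW + outX) * minorDim + c] = v;
    }
}

//------------------------------------------------------------------------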