// File: Rasm / dnnlib / tflib / ops / upfirdn_2d.cu
// Repository viewer metadata: JayR7, "First commit", baa26ad, 21.7 kB.
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#define EIGEN_USE_GPU
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include <stdio.h>
using namespace tensorflow;
using namespace tensorflow::shape_inference;
//------------------------------------------------------------------------
// Helpers.
// Evaluates CUDA_CALL exactly once and hands the outcome to OP_REQUIRES on CTX:
// any result other than cudaSuccess is reported as errors::Internal carrying
// the name returned by cudaGetErrorName().
// NOTE(review): OP_REQUIRES is expected to return from the enclosing function
// on failure -- confirm against the TensorFlow headers in use.
#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) \
    do \
    { \
        cudaError_t err = CUDA_CALL; \
        OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); \
    } while (false)
// Integer division of a by b rounded toward negative infinity.
// C++ '/' truncates toward zero, so the truncated quotient is lowered by one
// whenever the division is inexact and the remainder's sign differs from the
// divisor's sign. b must be non-zero.
static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
{
    int quotient = a / b;
    const int remainder = a - quotient * b;
    const bool inexact = (remainder != 0);
    const bool signsDiffer = ((remainder < 0) != (b < 0));
    if (inexact && signsDiffer)
        quotient -= 1;
    return quotient;
}
//------------------------------------------------------------------------
// CUDA kernel params.
// Argument bundle passed by value to the UpFirDn2D CUDA kernels.
// Member order is significant: the struct is copied to the device as-is.
template <class T>
struct UpFirDn2DKernelParams
{
    // Tensor data pointers.
    const T* x; // input,  laid out as [majorDim, inH, inW, minorDim]
    const T* k; // filter, laid out as [kernelH, kernelW]
    T* y;       // output, laid out as [majorDim, outH, outW, minorDim]

    // Resampling factors along X and Y (the op requires each to be >= 1).
    int upx, upy;
    int downx, downy;

    // Padding amounts. All four enter the output-size formula; the kernels
    // offset the receptive field by padx0 / pady0.
    int padx0, padx1;
    int pady0, pady1;

    // Input tensor extents.
    int majorDim, inH, inW, minorDim;

    // Filter extents.
    int kernelH, kernelW;

    // Output spatial extents.
    int outH, outW;

    // Work per launch unit: consecutive major indices handled per blockIdx.z,
    // and the maximum number of X steps taken by the kernels' loop over outX.
    int loopMajor, loopX;
};
//------------------------------------------------------------------------
// General CUDA implementation for large filter kernels.
// Generic kernel: every thread produces output elements one at a time by
// walking its receptive field directly in global memory.
//
// Thread mapping:
//   x (blockIdx.x, threadIdx.x): flattened (outY, minorIdx) pair
//   y (blockIdx.y, threadIdx.y): starting outX; the thread then advances by
//                                blockDim.y for up to p.loopX iterations
//   z (blockIdx.z):              group of p.loopMajor consecutive major indices
//
// Products are accumulated in float regardless of T and converted on store.
template <class T>
static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
{
    // Decode the coordinates that stay fixed for this thread.
    const int flatIdx = blockIdx.x * blockDim.x + threadIdx.x;
    const int outY = flatIdx / p.minorDim;
    const int minorIdx = flatIdx - outY * p.minorDim;
    const int firstOutX = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    const int firstMajor = blockIdx.z * p.loopMajor;

    const bool hasWork = (firstOutX < p.outW) && (outY < p.outH) && (firstMajor < p.majorDim);
    if (!hasWork)
        return;

    // Vertical receptive field: first input row, number of rows clamped to
    // the input, and the filter row paired with the first input row.
    const int midY = outY * p.downy + p.upy - 1 - p.pady0;
    const int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
    const int rows = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
    const int kernelY = midY + p.kernelH - (inY + 1) * p.upy;

    // Element strides. The input advances forwards while the filter is
    // traversed backwards, in steps of the upsampling factor.
    const int xStepX = p.minorDim;
    const int xStepY = p.inW * p.minorDim;
    const int kStepX = p.upx;
    const int kStepY = p.upy * p.kernelW;

    const int majorEnd = min(firstMajor + p.loopMajor, p.majorDim);
    for (int majorIdx = firstMajor; majorIdx < majorEnd; majorIdx++)
    {
        for (int step = 0, outX = firstOutX; step < p.loopX && outX < p.outW; step++, outX += blockDim.y)
        {
            // Horizontal receptive field, mirroring the vertical setup.
            const int midX = outX * p.downx + p.upx - 1 - p.padx0;
            const int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
            const int cols = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
            const int kernelX = midX + p.kernelW - (inX + 1) * p.upx;

            // Start-of-row pointers into the input and the filter.
            const T* xRow = p.x + (((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx);
            const T* kRow = p.k + (kernelY * p.kernelW + kernelX);

            // Accumulate input * filter over the clamped window.
            float acc = 0.0f;
            for (int r = 0; r < rows; r++)
            {
                const T* xTap = xRow;
                const T* kTap = kRow;
                for (int c = 0; c < cols; c++)
                {
                    acc += (float)(*xTap) * (float)(*kTap);
                    xTap += xStepX;
                    kTap -= kStepX;
                }
                xRow += xStepY;
                kRow -= kStepY;
            }

            // Write the finished output element.
            p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)acc;
        }
    }
}
//------------------------------------------------------------------------
// Specialized CUDA implementation for small filter kernels.
// Tiled kernel for filter sizes and resampling factors fixed at compile time.
//
// Template parameters:
//   upx, upy, downx, downy : resampling factors; the kernel uses these instead
//                            of the corresponding fields of p.
//   kernelW, kernelH       : compile-time filter extent. The runtime filter
//                            (p.kernelW x p.kernelH) is zero-padded up to this
//                            size when it is staged in shared memory.
//   tileOutW, tileOutH     : extent of the output tile a block produces per
//                            loop iteration.
//
// Block mapping:
//   blockIdx.x : flattened (output tile row, minorIdx) pair
//   blockIdx.y : group of p.loopX consecutive output tiles along X
//   blockIdx.z : group of p.loopMajor consecutive major indices
// Work inside a block is distributed over threadIdx.x only, with a stride of
// blockDim.x.
//
// The flipped filter (sk) and the current input tile (sx) are staged in
// shared memory as float; results are converted back to T on store.
template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
{
// Preconditions (left disabled): the inner loops below run kernelW / upx and
// kernelH / upy taps, so the filter extent must be a multiple of the
// upsampling factor.
//assert(kernelW % upx == 0);
//assert(kernelH % upy == 0);
// Extent of the input tile that is staged for one output tile.
const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
__shared__ volatile float sk[kernelH][kernelW];
__shared__ volatile float sx[tileInH][tileInW];
// Calculate tile index.
int minorIdx = blockIdx.x;
int tileOutY = minorIdx / p.minorDim;
minorIdx -= tileOutY * p.minorDim;
tileOutY *= tileOutH;
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
int majorIdxBase = blockIdx.z * p.loopMajor;
// The exit condition depends only on blockIdx and p, so every thread of a
// block takes the same path and the barriers below are reached uniformly.
// Bitwise | / & on the comparison results evaluates all operands without
// short-circuiting (presumably to avoid extra branches).
if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
return;
// Load filter kernel (flipped).
// Entries outside the runtime filter extent are filled with zero.
for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
{
int ky = tapIdx / kernelW;
int kx = tapIdx - ky * kernelW;
float v = 0.0f;
if (kx < p.kernelW & ky < p.kernelH)
v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
sk[ky][kx] = v;
}
// Loop over majorDim and outX.
for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
{
// Load input pixels.
// tileInX / tileInY is the input coordinate stored at sx[0][0].
int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
int tileInX = floorDiv(tileMidX, upx);
int tileInY = floorDiv(tileMidY, upy);
// Barrier: separates the previous iteration's reads of sx from the writes
// below (on the first iteration it follows the sk writes above).
__syncthreads();
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
{
int relInY = inIdx / tileInW;
int relInX = inIdx - relInY * tileInW;
int inX = relInX + tileInX;
int inY = relInY + tileInY;
// Positions outside the input image are staged as zero.
float v = 0.0f;
if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
sx[relInY][relInX] = v;
}
// Loop over output pixels.
// Barrier: the writes to sx (and sk) must be complete before any thread
// reads them.
__syncthreads();
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
{
int relOutY = outIdx / tileOutW;
int relOutX = outIdx - relOutY * tileOutW;
int outX = relOutX + tileOutX;
int outY = relOutY + tileOutY;
// Setup receptive field.
// relInX / relInY index the first input sample inside sx; kernelX / kernelY
// index the matching first tap inside the flipped filter sk.
int midX = tileMidX + relOutX * downx;
int midY = tileMidY + relOutY * downy;
int inX = floorDiv(midX, upx);
int inY = floorDiv(midY, upy);
int relInX = inX - tileInX;
int relInY = inY - tileInY;
int kernelX = (inX + 1) * upx - midX - 1; // flipped
int kernelY = (inY + 1) * upy - midY - 1; // flipped
// Inner loop.
// Trip counts are compile-time constants; consecutive input samples pair
// with filter taps spaced upx / upy apart.
float v = 0.0f;
#pragma unroll
for (int y = 0; y < kernelH / upy; y++)
#pragma unroll
for (int x = 0; x < kernelW / upx; x++)
v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
// Store result.
// Tiles may extend past the output edge, so the store is guarded.
if (outX < p.outW & outY < p.outH)
p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
}
}
}
//------------------------------------------------------------------------
// TensorFlow op.
// TensorFlow GPU op kernel for UpFirDn2D.
//
// The resampling and padding attributes are read once at construction time.
// Compute() validates the inputs, derives the output shape, selects a CUDA
// kernel (a specialized tiled kernel when one matches the configuration,
// otherwise the generic one) and launches it on the op's GPU stream.
template <class T>
struct UpFirDn2DOp : public OpKernel
{
    // Attribute values captured at construction; all remaining fields are
    // zero and are filled in per call by Compute().
    UpFirDn2DKernelParams<T> m_attribs;

    // Reads the upx/upy/downx/downy/padx0/padx1/pady0/pady1 attributes and
    // rejects resampling factors smaller than 1.
    UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx), m_attribs()
    {
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
        OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
        OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
    }

    // Inputs:  0 = x [majorDim, inH, inW, minorDim], 1 = k [kernelH, kernelW].
    // Output:  0 = y [majorDim, outH, outW, minorDim].
    // Fails with InvalidArgument on wrong ranks, an empty filter, an output
    // smaller than 1x1, or tensors whose element count exceeds int32 range;
    // fails with Internal if the kernel launch reports a CUDA error.
    void Compute(OpKernelContext* ctx) override
    {
        UpFirDn2DKernelParams<T> p = m_attribs;
        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();

        // Validate the inputs before touching their data.
        const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
        const Tensor& k = ctx->input(1); // [kernelH, kernelW]
        OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
        OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
        OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
        OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));

        p.x = x.flat<T>().data();
        p.k = k.flat<T>().data();
        p.majorDim = (int)x.dim_size(0);
        p.inH = (int)x.dim_size(1);
        p.inW = (int)x.dim_size(2);
        p.minorDim = (int)x.dim_size(3);
        p.kernelH = (int)k.dim_size(0);
        p.kernelW = (int)k.dim_size(1);
        OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));

        // Output extent. Computed in 64 bits so that inW * upx (and the
        // padding sum) cannot overflow before the range checks run.
        const long long outW64 = ((long long)p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
        const long long outH64 = ((long long)p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
        OP_REQUIRES(ctx, outW64 >= 1 && outH64 >= 1, errors::InvalidArgument("output must be at least 1x1"));
        OP_REQUIRES(ctx, outW64 <= kint32max && outH64 <= kint32max, errors::InvalidArgument("output too large"));
        p.outW = (int)outW64;
        p.outH = (int)outH64;

        // Reject oversized outputs before allocating them: the kernels index
        // with 32-bit ints.
        Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
        TensorShape ys;
        ys.AddDim(p.majorDim);
        ys.AddDim(p.outH);
        ys.AddDim(p.outW);
        ys.AddDim(p.minorDim);
        OP_REQUIRES(ctx, ys.num_elements() <= kint32max, errors::InvalidArgument("output too large"));
        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));

        // Nothing to compute for an empty output (majorDim == 0 or
        // minorDim == 0). Launching anyway would request a zero-sized grid
        // dimension, which CUDA rejects as an invalid configuration.
        if (y->NumElements() == 0)
            return;
        p.y = y->flat<T>().data();

        // Choose CUDA kernel to use. The generic kernel is the default; each
        // table row below overrides it when the configuration matches. Rows
        // are ordered from the largest to the smallest supported filter
        // within each group, so the last matching row (the tightest fit)
        // wins.
        void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
        int tileOutW = -1;
        int tileOutH = -1;

        // Arguments: upx,upy, downx,downy, max kernelW,kernelH, tileOutW,tileOutH.
        #define UPFIRDN2D_SELECT_SMALL(UPX, UPY, DOWNX, DOWNY, KW, KH, TW, TH) \
            if (p.upx == UPX && p.upy == UPY && p.downx == DOWNX && p.downy == DOWNY && p.kernelW <= KW && p.kernelH <= KH) \
            { \
                cudaKernel = (void*)UpFirDn2DKernel_small<T, UPX, UPY, DOWNX, DOWNY, KW, KH, TW, TH>; \
                tileOutW = TW; \
                tileOutH = TH; \
            }

        // No resampling.
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  7, 7,  64,16)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  6, 6,  64,16)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  5, 5,  64,16)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  4, 4,  64,16)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  3, 3,  64,16)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1, 24, 1, 128, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1, 20, 1, 128, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1, 16, 1, 128, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1, 12, 1, 128, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  8, 1, 128, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  1,24,  32,32)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  1,20,  32,32)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  1,16,  32,32)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  1,12,  32,32)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,1,  1, 8,  32,32)
        // 2x upsampling in both directions.
        UPFIRDN2D_SELECT_SMALL(2,2, 1,1,  8, 8,  64,16)
        UPFIRDN2D_SELECT_SMALL(2,2, 1,1,  6, 6,  64,16)
        UPFIRDN2D_SELECT_SMALL(2,2, 1,1,  4, 4,  64,16)
        UPFIRDN2D_SELECT_SMALL(2,2, 1,1,  2, 2,  64,16)
        // 2x upsampling along X only.
        UPFIRDN2D_SELECT_SMALL(2,1, 1,1, 24, 1, 128, 8)
        UPFIRDN2D_SELECT_SMALL(2,1, 1,1, 20, 1, 128, 8)
        UPFIRDN2D_SELECT_SMALL(2,1, 1,1, 16, 1, 128, 8)
        UPFIRDN2D_SELECT_SMALL(2,1, 1,1, 12, 1, 128, 8)
        UPFIRDN2D_SELECT_SMALL(2,1, 1,1,  8, 1, 128, 8)
        // 2x upsampling along Y only.
        UPFIRDN2D_SELECT_SMALL(1,2, 1,1,  1,24,  32,32)
        UPFIRDN2D_SELECT_SMALL(1,2, 1,1,  1,20,  32,32)
        UPFIRDN2D_SELECT_SMALL(1,2, 1,1,  1,16,  32,32)
        UPFIRDN2D_SELECT_SMALL(1,2, 1,1,  1,12,  32,32)
        UPFIRDN2D_SELECT_SMALL(1,2, 1,1,  1, 8,  32,32)
        // 2x downsampling in both directions.
        UPFIRDN2D_SELECT_SMALL(1,1, 2,2,  8, 8,  32, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 2,2,  6, 6,  32, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 2,2,  4, 4,  32, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 2,2,  2, 2,  32, 8)
        // 2x downsampling along X only.
        UPFIRDN2D_SELECT_SMALL(1,1, 2,1, 24, 1,  64, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 2,1, 20, 1,  64, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 2,1, 16, 1,  64, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 2,1, 12, 1,  64, 8)
        UPFIRDN2D_SELECT_SMALL(1,1, 2,1,  8, 1,  64, 8)
        // 2x downsampling along Y only.
        UPFIRDN2D_SELECT_SMALL(1,1, 1,2,  1,24,  32,16)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,2,  1,20,  32,16)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,2,  1,16,  32,16)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,2,  1,12,  32,16)
        UPFIRDN2D_SELECT_SMALL(1,1, 1,2,  1, 8,  32,16)

        #undef UPFIRDN2D_SELECT_SMALL

        // Choose launch params. loopMajor is sized so that the z grid
        // dimension never exceeds 16384 blocks.
        dim3 blockSize;
        dim3 gridSize;
        if (tileOutW > 0 && tileOutH > 0) // small
        {
            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
            p.loopX = 1;
            blockSize = dim3(32 * 8, 1, 1);
            gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }
        else // large
        {
            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
            p.loopX = 4;
            blockSize = dim3(4, 32, 1);
            gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }

        // Launch CUDA kernel. The parameter struct is passed by value.
        void* args[] = {&p};
        OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
    }
};
// Op interface for "UpFirDn2D".
//   Inputs : x and k, both of element type T.
//   Output : y, of element type T.
//   T      : float or half.
//   Attrs  : upx/upy/downx/downy default to 1; padx0/padx1/pady0/pady1
//            default to 0.
// No shape function is attached to this registration.
REGISTER_OP("UpFirDn2D")
.Input ("x: T")
.Input ("k: T")
.Output ("y: T")
.Attr ("T: {float, half}")
.Attr ("upx: int = 1")
.Attr ("upy: int = 1")
.Attr ("downx: int = 1")
.Attr ("downy: int = 1")
.Attr ("padx0: int = 0")
.Attr ("padx1: int = 0")
.Attr ("pady0: int = 0")
.Attr ("pady1: int = 0");
// GPU kernel registrations: one UpFirDn2DOp instantiation per supported T.
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
//------------------------------------------------------------------------