|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#define EIGEN_USE_GPU |
|
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ |
|
#include "tensorflow/core/framework/op.h" |
|
#include "tensorflow/core/framework/op_kernel.h" |
|
#include "tensorflow/core/framework/shape_inference.h" |
|
#include <stdio.h> |
|
|
|
using namespace tensorflow; |
|
using namespace tensorflow::shape_inference; |
|
|
|
|
|
|
|
|
|
// Evaluates a CUDA runtime call once. If the returned status is not cudaSuccess,
// the op is failed via OP_REQUIRES with an errors::Internal status whose message
// is the CUDA error name.
#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL)                                                           \
    do                                                                                                \
    {                                                                                                 \
        const cudaError_t cudaStatus_ = (CUDA_CALL);                                                  \
        OP_REQUIRES(CTX, cudaStatus_ == cudaSuccess, errors::Internal(cudaGetErrorName(cudaStatus_))); \
    } while (false)
|
|
|
// Integer division rounded toward negative infinity, valid for a positive divisor b.
// C++ '/' truncates toward zero, which is wrong for negative numerators. So we first
// add enough multiples of b to make the numerator positive, divide, and then
// subtract the same number of multiples back out of the quotient.
static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
{
    const int shift = 1 - a / b;      // multiples of b added to the numerator
    const int shifted = a + shift * b;
    return shifted / b - shift;
}
|
|
|
|
|
|
|
|
|
// Parameter block shared by both UpFirDn2D kernels; passed to the kernel by value.
//
// Layouts as indexed by the kernels:
//   x: [majorDim, inH, inW, minorDim]
//   k: [kernelH, kernelW]
//   y: [majorDim, outH, outW, minorDim]
template <class T>
struct UpFirDn2DKernelParams
{
    // Device buffers.
    const T* x;     // input tensor
    const T* k;     // FIR filter taps
    T*       y;     // output tensor

    // Resampling factors and per-side padding (copied from the op attributes).
    int upx, upy;
    int downx, downy;
    int padx0, padx1;
    int pady0, pady1;

    // Tensor extents.
    int majorDim, inH, inW, minorDim;
    int kernelH, kernelW;
    int outH, outW;

    // Work multipliers chosen at launch time: how many consecutive majorDim entries
    // and how many output-X steps each block/thread iterates over.
    int loopMajor, loopX;
};
|
|
|
|
|
|
|
|
|
// Generic UpFirDn2D kernel: handles any up/down factors and any filter size.
//
// Thread mapping (as computed below):
//   x: flattened (outY, minorIdx) pair, one per thread
//   y: output column; each thread visits up to p.loopX columns spaced blockDim.y apart
//   z: group of p.loopMajor consecutive majorDim indices
// Each output value is the dot product of the input pixels in its receptive field
// with the filter traversed backwards. Accumulation is done in float for every T.
template <class T>
static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
{
    const int flatIdx = blockIdx.x * blockDim.x + threadIdx.x;
    const int outY = flatIdx / p.minorDim;
    const int minorIdx = flatIdx - outY * p.minorDim;
    const int firstOutX = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    const int firstMajor = blockIdx.z * p.loopMajor;

    // Threads that fall outside the output have nothing to do.
    if (!(firstOutX < p.outW && outY < p.outH && firstMajor < p.majorDim))
        return;

    // Vertical receptive field. It is shared by every column this thread visits.
    const int midY = outY * p.downy + p.upy - 1 - p.pady0;
    const int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
    const int rows = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
    const int kernelY = midY + p.kernelH - (inY + 1) * p.upy;

    const int rowStride = p.inW * p.minorDim;
    const int colStep = blockDim.y;

    for (int m = 0; m < p.loopMajor; m++)
    {
        const int majorIdx = firstMajor + m;
        if (majorIdx >= p.majorDim)
            break;

        for (int j = 0; j < p.loopX; j++)
        {
            const int outX = firstOutX + j * colStep;
            if (outX >= p.outW)
                break;

            // Horizontal receptive field.
            const int midX = outX * p.downx + p.upx - 1 - p.padx0;
            const int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
            const int cols = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
            const int kernelX = midX + p.kernelW - (inX + 1) * p.upx;

            // Input advances forward by pixels. The filter advances backward by up-factor taps.
            const int xOrigin = ((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx;
            const int kOrigin = kernelY * p.kernelW + kernelX;

            float acc = 0.0f;
            for (int dy = 0; dy < rows; dy++)
                for (int dx = 0; dx < cols; dx++)
                    acc += (float)p.x[xOrigin + dy * rowStride + dx * p.minorDim]
                         * (float)p.k[kOrigin - dy * p.upy * p.kernelW - dx * p.upx];

            p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)acc;
        }
    }
}
|
|
|
|
|
|
|
|
|
// Specialized UpFirDn2D kernel with compile-time resampling factors, maximum filter
// size (kernelW x kernelH) and output tile size (tileOutW x tileOutH).
//
// Grid mapping (as computed below):
//   blockIdx.x = tileRow * p.minorDim + minorIdx
//   blockIdx.y = group of p.loopX output tiles along X
//   blockIdx.z = group of p.loopMajor consecutive majorDim indices
// All work loops stride by blockDim.x, so any 1D block size is accepted.
//
// Shared memory:
//   sk - filter, flipped, zero-padded from p.kernelH x p.kernelW up to kernelH x kernelW
//   sx - input footprint of the current tile, zero outside the input bounds
// The tap loops run (kernelH / upy) x (kernelW / upx) times. Every instantiation in this
// file uses filter sizes divisible by the up factors.
template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
{
    const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
    __shared__ volatile float sk[kernelH][kernelW];
    __shared__ volatile float sx[tileInH][tileInW];

    // Decode which tile / channel / batch slice this block owns.
    const int blockX = blockIdx.x;
    const int tileRow = blockX / p.minorDim;
    const int minorIdx = blockX - tileRow * p.minorDim;
    const int tileOutY = tileRow * tileOutH;
    const int firstTileOutX = blockIdx.y * p.loopX * tileOutW;
    const int firstMajor = blockIdx.z * p.loopMajor;

    // This condition depends only on blockIdx and p, so the whole block exits together.
    // That matters because of the __syncthreads() calls below.
    if (firstTileOutX >= p.outW || tileOutY >= p.outH || firstMajor >= p.majorDim)
        return;

    const int stride = blockDim.x;

    // Stage the flipped filter; taps beyond the runtime filter size are zero.
    for (int tap = threadIdx.x; tap < kernelH * kernelW; tap += stride)
    {
        const int ky = tap / kernelW;
        const int kx = tap % kernelW;
        const bool inFilter = kx < p.kernelW && ky < p.kernelH;
        sk[ky][kx] = inFilter ? (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)] : 0.0f;
    }

    for (int m = 0; m < p.loopMajor; m++)
    {
        const int majorIdx = firstMajor + m;
        if (majorIdx >= p.majorDim)
            break;

        for (int t = 0; t < p.loopX; t++)
        {
            const int tileOutX = firstTileOutX + t * tileOutW;
            if (tileOutX >= p.outW)
                break;

            // Top-left input pixel that feeds this output tile.
            const int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
            const int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
            const int tileInX = floorDiv(tileMidX, upx);
            const int tileInY = floorDiv(tileMidY, upy);

            // Barrier: sk is complete, and no thread is still reading the previous sx.
            __syncthreads();
            for (int i = threadIdx.x; i < tileInH * tileInW; i += stride)
            {
                const int ry = i / tileInW;
                const int rx = i % tileInW;
                const int srcX = tileInX + rx;
                const int srcY = tileInY + ry;
                const bool inBounds = srcX >= 0 && srcY >= 0 && srcX < p.inW && srcY < p.inH;
                sx[ry][rx] = inBounds ? (float)p.x[((majorIdx * p.inH + srcY) * p.inW + srcX) * p.minorDim + minorIdx] : 0.0f;
            }

            // Barrier: sx is fully populated before anyone reads it.
            __syncthreads();
            for (int o = threadIdx.x; o < tileOutH * tileOutW; o += stride)
            {
                const int relOutY = o / tileOutW;
                const int relOutX = o % tileOutW;

                const int midX = tileMidX + relOutX * downx;
                const int midY = tileMidY + relOutY * downy;
                const int srcX = floorDiv(midX, upx);
                const int srcY = floorDiv(midY, upy);
                const int relInX = srcX - tileInX;
                const int relInY = srcY - tileInY;
                const int tapX = (srcX + 1) * upx - midX - 1; // index into the flipped filter
                const int tapY = (srcY + 1) * upy - midY - 1;

                float acc = 0.0f;
                #pragma unroll
                for (int ty = 0; ty < kernelH / upy; ty++)
                    #pragma unroll
                    for (int tx = 0; tx < kernelW / upx; tx++)
                        acc += sx[relInY + ty][relInX + tx] * sk[tapY + ty * upy][tapX + tx * upx];

                const int outX = tileOutX + relOutX;
                const int outY = tileOutY + relOutY;
                if (outX < p.outW && outY < p.outH)
                    p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)acc;
            }
        }
    }
}
|
|
|
|
|
|
|
|
|
// TensorFlow GPU OpKernel for "UpFirDn2D".
//
// Inputs:  x [majorDim, inH, inW, minorDim], k [kernelH, kernelW]
// Output:  y [majorDim, outH, outW, minorDim], where
//          outW = (inW * upx + padx0 + padx1 - kernelW + downx) / downx   (likewise outH).
// If the up/down factors and filter size match one of the compiled instantiations,
// the shared-memory kernel UpFirDn2DKernel_small is used. Otherwise the generic
// UpFirDn2DKernel_large is used.
template <class T>
struct UpFirDn2DOp : public OpKernel
{
    UpFirDn2DKernelParams<T> m_attribs; // resampling attributes; remaining fields are filled per call

    UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
    {
        memset(&m_attribs, 0, sizeof(m_attribs));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
        OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
        OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
    }

    void Compute(OpKernelContext* ctx) override
    {
        UpFirDn2DKernelParams<T> p = m_attribs;
        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();

        // Validate inputs. The kernels index with 32-bit ints throughout.
        const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
        const Tensor& k = ctx->input(1); // [kernelH, kernelW]
        OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
        OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
        OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
        OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
        // An empty tensor can still carry a single dimension beyond int32 range; reject it
        // instead of silently truncating it in the casts below.
        for (int i = 0; i < 4; i++)
        {
            OP_REQUIRES(ctx, x.dim_size(i) <= kint32max, errors::InvalidArgument("input too large"));
        }
        p.x = x.flat<T>().data();
        p.k = k.flat<T>().data();

        p.majorDim = (int)x.dim_size(0);
        p.inH      = (int)x.dim_size(1);
        p.inW      = (int)x.dim_size(2);
        p.minorDim = (int)x.dim_size(3);
        p.kernelH  = (int)k.dim_size(0);
        p.kernelW  = (int)k.dim_size(1);
        OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));

        // Output extent. It is evaluated in 64-bit so large up-sampling factors or pads cannot
        // overflow int (signed overflow is undefined behavior). Only then is it narrowed.
        const long long outW = ((long long)p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
        const long long outH = ((long long)p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
        OP_REQUIRES(ctx, outW >= 1 && outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
        OP_REQUIRES(ctx, outW <= kint32max && outH <= kint32max, errors::InvalidArgument("output too large"));
        p.outW = (int)outW;
        p.outH = (int)outH;

        Tensor* y = nullptr; // [majorDim, outH, outW, minorDim]
        TensorShape ys;
        ys.AddDim(p.majorDim);
        ys.AddDim(p.outH);
        ys.AddDim(p.outW);
        ys.AddDim(p.minorDim);
        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
        p.y = y->flat<T>().data();
        OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));

        // An empty output (majorDim == 0 or minorDim == 0) needs no work. Launching anyway
        // would use a zero grid dimension, which cudaLaunchKernel rejects. In the generic
        // path, minorDim == 0 would also make the kernel divide by zero.
        if (y->NumElements() == 0)
            return;

        // Choose the CUDA kernel. Later matches override earlier ones, so the smallest
        // fitting specialization wins. If nothing matches, the generic kernel is used.
        void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
        int tileOutW = -1;
        int tileOutH = -1;

        // No up/downsampling.
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7  && p.kernelH <= 7 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5  && p.kernelH <= 5 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3  && p.kernelH <= 3 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 8,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,24, 32,32>; tileOutW = 32; tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,20, 32,32>; tileOutW = 32; tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,16, 32,32>; tileOutW = 32; tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,12, 32,32>; tileOutW = 32; tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,8, 32,32>; tileOutW = 32; tileOutH = 32; }

        // 2x upsampling.
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 8,1, 128,8>; tileOutW = 128; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,24, 32,32>; tileOutW = 32; tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,20, 32,32>; tileOutW = 32; tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,16, 32,32>; tileOutW = 32; tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,12, 32,32>; tileOutW = 32; tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,8, 32,32>; tileOutW = 32; tileOutH = 32; }

        // 2x downsampling.
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8 >; tileOutW = 32; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8 >; tileOutW = 32; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8 >; tileOutW = 32; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8 >; tileOutW = 32; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 24,1, 64,8 >; tileOutW = 64; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 20,1, 64,8 >; tileOutW = 64; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 16,1, 64,8 >; tileOutW = 64; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 12,1, 64,8 >; tileOutW = 64; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 8,1, 64,8 >; tileOutW = 64; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,24, 32,16>; tileOutW = 32; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,20, 32,16>; tileOutW = 32; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,16, 32,16>; tileOutW = 32; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,12, 32,16>; tileOutW = 32; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,8, 32,16>; tileOutW = 32; tileOutH = 16; }

        // Launch configuration. CUDA limits gridDim.y and gridDim.z to 65535.
        // gridDim.z stays within that through loopMajor (at most 16384 blocks).
        // gridDim.y stays within it by raising loopX for very wide outputs. For every
        // output width that fit before, loopX keeps its previous value (1 small / 4 large).
        const int maxGridDimYZ = 65535;
        dim3 blockSize;
        dim3 gridSize;
        p.loopMajor = (p.majorDim - 1) / 16384 + 1;
        if (tileOutW > 0 && tileOutH > 0) // specialized shared-memory kernel
        {
            const int tilesX = (p.outW - 1) / tileOutW + 1;
            p.loopX = (tilesX - 1) / maxGridDimYZ + 1;
            blockSize = dim3(32 * 8, 1, 1);
            gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }
        else // generic kernel
        {
            blockSize = dim3(4, 32, 1);
            const int columnGroups = (p.outW - 1) / (int)blockSize.y + 1;
            p.loopX = (columnGroups - 1) / maxGridDimYZ + 1;
            if (p.loopX < 4)
                p.loopX = 4;
            gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }

        // Launch on TensorFlow's stream. Launch-configuration errors surface through the return code.
        void* args[] = {&p};
        OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
    }
};
|
|
|
// Registers the UpFirDn2D op and its GPU kernels for float and half.
// The attributes are integer up/down factors per axis and per-side padding
// (padx0/padx1 along W, pady0/pady1 along H).
// The shape function mirrors the kernel: x must be rank 4 [majorDim, inH, inW, minorDim],
// k must be rank 2 [kernelH, kernelW], and each spatial output extent is
// (in * up + pad0 + pad1 - kernel + down) / down. When an input extent is not statically
// known, or the formula yields < 1, the dimension is left unknown and the kernel
// performs the runtime check.
REGISTER_OP("UpFirDn2D")
    .Input ("x: T")
    .Input ("k: T")
    .Output ("y: T")
    .Attr ("T: {float, half}")
    .Attr ("upx: int = 1")
    .Attr ("upy: int = 1")
    .Attr ("downx: int = 1")
    .Attr ("downy: int = 1")
    .Attr ("padx0: int = 0")
    .Attr ("padx1: int = 0")
    .Attr ("pady0: int = 0")
    .Attr ("pady1: int = 0")
    .SetShapeFn([](InferenceContext* c) -> Status
    {
        ShapeHandle x;
        ShapeHandle k;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
        TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &k));

        int upx, upy, downx, downy, padx0, padx1, pady0, pady1;
        TF_RETURN_IF_ERROR(c->GetAttr("upx", &upx));
        TF_RETURN_IF_ERROR(c->GetAttr("upy", &upy));
        TF_RETURN_IF_ERROR(c->GetAttr("downx", &downx));
        TF_RETURN_IF_ERROR(c->GetAttr("downy", &downy));
        TF_RETURN_IF_ERROR(c->GetAttr("padx0", &padx0));
        TF_RETURN_IF_ERROR(c->GetAttr("padx1", &padx1));
        TF_RETURN_IF_ERROR(c->GetAttr("pady0", &pady0));
        TF_RETURN_IF_ERROR(c->GetAttr("pady1", &pady1));
        // Same validation as the kernel constructor. It also prevents division by zero below.
        if (upx < 1 || upy < 1)
            return errors::InvalidArgument("upx and upy must be at least 1x1");
        if (downx < 1 || downy < 1)
            return errors::InvalidArgument("downx and downy must be at least 1x1");

        auto outDim = [c](DimensionHandle in, DimensionHandle kernel, int up, int down, int pad0, int pad1) -> DimensionHandle
        {
            if (!c->ValueKnown(in) || !c->ValueKnown(kernel))
                return c->UnknownDim();
            const auto n = (c->Value(in) * up + pad0 + pad1 - c->Value(kernel) + down) / down;
            return n >= 1 ? c->MakeDim(n) : c->UnknownDim();
        };

        c->set_output(0, c->MakeShape({
            c->Dim(x, 0),
            outDim(c->Dim(x, 1), c->Dim(k, 0), upy, downy, pady0, pady1),
            outDim(c->Dim(x, 2), c->Dim(k, 1), upx, downx, padx0, padx1),
            c->Dim(x, 3)}));
        return Status();
    });

REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
|
|
|
|
|
|