/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "../cuda_utils.cu"
#include "dynamicconv_cuda.cuh"
#include "dynamicconv_cuda_backward.cu"
#include "dynamicconv_cuda_forward.cu"

// FS is filter size and kernels are specialized for filter sizes

/**
 * Forward pass of a dynamic convolution: every (batch, head, time step) has
 * its own FS-tap filter stored in `weight`, applied along the time dimension
 * of `input`.
 *
 * Launch layout (from the indexing below):
 *   grid  = (minibatch, numFeatures), block = (SB) with blockDim.x == SB
 *   (asserted); each block processes one (batch, feature) row of length
 *   `sequenceLength` in chunks of SB time steps.
 *
 * Template parameters:
 *   FS        - filter size (compile-time, enables full unrolling)
 *   SB        - chunk / block width in time steps
 *   padding_l - left padding applied when staging input into shared memory
 *
 * Shared memory: SB + FS elements of scalar_t per block (static).
 *
 * input  : B x C x T  (C = numFeatures, T = sequenceLength)
 * weight : B x H x FS x T  (H = numHeads) — per-position filters
 * output : B x C x T
 */
template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void dynamicconv_forward_kernel(
    const scalar_t* input,
    const scalar_t* weight,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    scalar_t* output) {
  assert(blockDim.x == SB);

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  // Features are grouped into heads; all features in a head share filters.
  const int head = featureIdx / numFiltersInBlock;

  const int IOOffset =
      batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  scalar_t* outputFeature = &output[IOOffset];

  // Per-thread filter taps for the current time step (registers).
  scalar_t filter[FS];

  // Staging buffer: one SB-wide chunk of input plus FS halo elements.
  __shared__ scalar_t tempInput[SB + FS];
  zeroSharedMem<FS, SB, padding_l>(tempInput);

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  for (int i = 0; i < numIterations; ++i) {
    // Barrier before reloading shared memory: the previous iteration's
    // reads must complete first.
    __syncthreads();
    const int inputOffset = i * SB;
    load_input_to_shared<FS, SB, padding_l>(
        inputFeature,
        inputOffset,
        sequenceLength,
        i,
        numIterations,
        false,
        tempInput);
    __syncthreads();
    if (inputOffset + tid < sequenceLength) {
      // Gather this time step's FS filter taps. Adjacent threads read
      // adjacent addresses (… + tid), so the loads coalesce.
#pragma unroll
      for (int k = 0; k < FS; ++k) {
        const int filterOffset = batchIdx * numHeads * FS * sequenceLength +
            head * FS * sequenceLength + k * sequenceLength + i * SB + tid;
        filter[k] = weight[filterOffset];
      }

      scalar_t out = scalar_t(0.0);
#pragma unroll
      for (int k = 0; k < FS; ++k) {
        out += filter[k] * tempInput[tid + k];
      }

      outputFeature[inputOffset + tid] = out;
    }
  }
}

/**
 * Backward pass of the dynamic convolution: computes the gradient w.r.t. the
 * input (`gradInput`) and w.r.t. the per-position filters (`gradWeight`)
 * from the output gradient.
 *
 * Launch layout (from the indexing below):
 *   grid  = (minibatch, numHeads, numChunks), block = (SB) with
 *   blockDim.x == SB (asserted); each block handles one (batch, head) pair
 *   on one SB-wide chunk of the time axis, iterating over the head's
 *   `numFiltersInBlock` channels.
 *
 * Shared memory: 2 * (SB + FS) elements of scalar_t per block (static).
 *
 * gradOutput : B * C * T
 * input      : B * C * T
 * weight     : B * H * FS * T
 * gradWeight : B * H * FS * T
 * gradInput  : B * C * T
 */
template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void dynamicconv_backward_kernel(
    const scalar_t* gradOutput, // B * C * T
    const scalar_t* input, // B * C * T
    const scalar_t* weight,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    scalar_t* gradWeight,
    scalar_t* gradInput) { // B * H * k * T
  assert(blockDim.x == SB);

  // each block operates on a single batch and filter head
  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int headIdx = blockIdx.y;
  const int chunkIdx = blockIdx.z;

  const int numChunks = divUp<int, int>(sequenceLength, SB);
  const int inputOffset = chunkIdx * SB;

  // initialize shared memory for output gradient and input
  __shared__ scalar_t tempGradOutput[SB + FS];
  __shared__ scalar_t tempInput[SB + FS];
  // Right padding of the flipped filter; constant expression since FS and
  // padding_l are template parameters, so it is usable as a template argument.
  const int padding = FS - padding_l - 1;

  zeroSharedMem<FS, SB, padding>(tempGradOutput);
  zeroSharedMem<FS, SB, padding_l>(tempInput);

  // initialize local filter and weight gradient sum arrays
  scalar_t tempGradSum[FS];
  scalar_t bfilter[FS];
  for (int k = 0; k < FS; ++k) {
    tempGradSum[k] = scalar_t(0.0);

    int idxOffset = inputOffset + tid + k - padding;
    if (idxOffset >= 0 && idxOffset < sequenceLength) {
      // Load the time-reversed filter (index FS - k - 1): the input
      // gradient is a correlation with the flipped weights.
      int bfilterOffset = batchIdx * numHeads * FS * sequenceLength +
          headIdx * FS * sequenceLength + (FS - k - 1) * sequenceLength +
          idxOffset;
      bfilter[k] = weight[bfilterOffset];
    } else {
      bfilter[k] = scalar_t(0.0);
    }
  }

  // iterate over filter block
  for (int featureIdx = 0; featureIdx < numFiltersInBlock; ++featureIdx) {
    // Barrier before overwriting the shared buffers for the next channel.
    __syncthreads();

    // load input and output gradient for this channel and chunk
    const int IOOffset = batchIdx * numFeatures * sequenceLength +
        (headIdx * numFiltersInBlock + featureIdx) * sequenceLength;
    const scalar_t* inputFeature = &input[IOOffset];
    const scalar_t* gradOutputFeature = &gradOutput[IOOffset];
    scalar_t* gradInputFeature = &gradInput[IOOffset];

    load_input_to_shared<FS, SB, padding>(
        gradOutputFeature,
        inputOffset,
        sequenceLength,
        chunkIdx,
        numChunks,
        true,
        tempGradOutput);
    load_input_to_shared<FS, SB, padding_l>(
        inputFeature,
        inputOffset,
        sequenceLength,
        chunkIdx,
        numChunks,
        true,
        tempInput);
    __syncthreads();

    // sum input and weight gradients
    scalar_t out = scalar_t(0.0);
#pragma unroll
    for (int k = 0; k < FS; ++k) {
      // Weight gradient accumulates input * gradOutput across the head's
      // channels; input gradient correlates gradOutput with the flipped
      // filter loaded above.
      tempGradSum[k] += tempInput[tid + k] * tempGradOutput[tid + padding];
      out += bfilter[k] * tempGradOutput[tid + k];
    }

    if (inputOffset + tid < sequenceLength) {
      gradInputFeature[inputOffset + tid] = out;
    }
  }

  const int gradOffset =
      batchIdx * numHeads * FS * sequenceLength + headIdx * FS * sequenceLength;
  scalar_t* gradWeightFeature = &gradWeight[gradOffset];

  // write weight gradient
  if (inputOffset + tid < sequenceLength) {
    for (int k = 0; k < FS; ++k) {
      const int outputOffset = k * sequenceLength + inputOffset + tid;
      gradWeightFeature[outputOffset] = tempGradSum[k];
    }
  }
}