# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
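"""Generate the specialized CUDA sources for the LightConv extension.

Running this script writes two files next to it, ``lightconv_cuda_forward.cu``
and ``lightconv_cuda_backward.cu``, each containing one explicitly instantiated
kernel launch per supported (filter size, sequence length, padding) combination,
wrapped in a plain chain of ifs and switches for runtime dispatch.
"""
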
def gen_forward():
    kernels = [3, 5, 7, 15, 31, 63, 127, 255]
    seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]]

    head = """
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "lightconv_cuda.cuh"

std::vector<at::Tensor> lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l) {

    at::DeviceGuard g(input.device());
    const auto minibatch = input.size(0);
    const auto numFeatures = input.size(1);
    const auto sequenceLength = input.size(2);

    const auto numHeads = filters.size(0);
    const auto filterSize = filters.size(1);

    const auto numFiltersInBlock = numFeatures / numHeads;

    const dim3 blocks(minibatch, numFeatures);

    auto output = at::zeros_like(input);
    auto stream = at::cuda::getCurrentCUDAStream();
"""
    sequence_if = """
    if (sequenceLength <= {seq}) {{
        switch(filterSize) {{
"""

    case_k = """
            case {k}:
"""

    main_block = """
                if (padding_l == {pad}) {{
                    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_forward", ([&] {{
                        lightconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t>
                        <<<blocks, {b_size}, 0, stream>>>(
                            input.data<scalar_t>(),
                            filters.data<scalar_t>(),
                            minibatch,
                            sequenceLength,
                            numFeatures,
                            numFiltersInBlock,
                            output.data<scalar_t>());
                    }}));
                }} else
"""

    bad_padding = """
                {
                    std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl;
                }
                break;
"""

    bad_filter = """
            default:
                std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl;
        }
"""

    con_else = """
    } else
"""

    final_else = """
    {
        switch(filterSize) {
"""

    final_return = """
    }

    return {output};
}
"""
    with open("lightconv_cuda_forward.cu", "w") as forward:
        forward.write(head)

        for seq in seqs:
            forward.write(sequence_if.format(seq=seq))
            for k in kernels:
                forward.write(case_k.format(k=k))
                for pad in [k // 2, k - 1]:
                    forward.write(main_block.format(k=k, b_size=seq, pad=pad))
                forward.write(bad_padding)
            forward.write(bad_filter)
            forward.write(con_else)

        # Fallback branch for sequences longer than every threshold above;
        # it reuses the last (largest) seq value as the thread-block size.
        forward.write(final_else)
        for k in kernels:
            forward.write(case_k.format(k=k))
            for pad in [k // 2, k - 1]:
                forward.write(main_block.format(k=k, b_size=seq, pad=pad))
            forward.write(bad_padding)
        forward.write(bad_filter)
        forward.write(final_return)
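
# The backward generator mirrors gen_forward() but dispatches on filterSize
# first (outer switch) and then on sequenceLength. For short sequences it emits
# the per-head "short" weight-gradient kernel pair; otherwise it falls back to
# the generic two-pass kernels with a fixed thread-block size of 32.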
def gen_backward():
    head = """
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "lightconv_cuda.cuh"

std::vector<at::Tensor> lightconv_cuda_backward(
    at::Tensor gradOutput,
    int padding_l,
    at::Tensor input,
    at::Tensor filters) {

    // gradWrtInput
    const int minibatch = input.size(0);
    const int numFeatures = input.size(1);
    const int sequenceLength = input.size(2);

    const int numHeads = filters.size(0);
    const int filterSize = filters.size(1);

    const dim3 gradBlocks(minibatch, numFeatures);
    const dim3 weightGradFirstpassShortBlocks(minibatch, numHeads);
    const dim3 weightGradSecondpassBlocks(numHeads, filterSize);

    const int numFiltersInBlock = numFeatures / numHeads;

    auto gradInput = at::zeros_like(input);
    auto gradFilters = at::zeros_like(filters);

    at::DeviceGuard g(input.device());
    auto stream = at::cuda::getCurrentCUDAStream();

    switch(filterSize) {
"""
    sequence_if = """
        if (sequenceLength <= {seq}) {{
"""

    case_k = """
        case {k}:
"""

    main_block = """
            if (padding_l == {p}) {{
                AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_backward", ([&] {{
                    lightconv_grad_wrt_input_kernel<{k}, {b_size}, {p}, scalar_t>
                    <<<gradBlocks, {b_size}, 0, stream>>>(
                        gradOutput.data<scalar_t>(),
                        filters.data<scalar_t>(),
                        minibatch,
                        sequenceLength,
                        numFeatures,
                        numFiltersInBlock,
                        gradInput.data<scalar_t>());
"""

    weight_grad_short = """
                    at::Tensor tempSumGradFilters = at::zeros({{minibatch, numHeads, filterSize}}, input.options().dtype(at::kFloat));
                    lightconv_grad_wrt_weights_firstpass_short_kernel<{k}, {b_size}, {p}, scalar_t>
                    <<<weightGradFirstpassShortBlocks, {b_size}, 0, stream>>>(
                        input.data<scalar_t>(),
                        gradOutput.data<scalar_t>(),
                        minibatch,
                        sequenceLength,
                        numFeatures,
                        numFiltersInBlock,
                        numHeads,
                        tempSumGradFilters.data<float>()
                    );

                    lightconv_grad_wrt_weights_secondpass_short_kernel<{k}, {b_size}, scalar_t>
                    <<<weightGradSecondpassBlocks, {b_size}, 0, stream>>>(
                        tempSumGradFilters.data<float>(),
                        minibatch,
                        numFiltersInBlock,
                        gradFilters.data<scalar_t>()
                    );
                }}));
            }} else
"""

    weight_grad = """
                    at::Tensor tempSumGradFilters = at::zeros({{minibatch, numFeatures, filterSize}}, input.options().dtype(at::kFloat));
                    lightconv_grad_wrt_weights_firstpass_kernel<{k}, {b_size}, {p}, scalar_t>
                    <<<gradBlocks, {b_size}, 0, stream>>>(
                        input.data<scalar_t>(),
                        gradOutput.data<scalar_t>(),
                        minibatch,
                        sequenceLength,
                        numFeatures,
                        numFiltersInBlock,
                        tempSumGradFilters.data<float>()
                    );

                    lightconv_grad_wrt_weights_secondpass_kernel<{k}, {b_size}, scalar_t>
                    <<<weightGradSecondpassBlocks, {b_size}, 0, stream>>>(
                        tempSumGradFilters.data<float>(),
                        minibatch,
                        numFiltersInBlock,
                        gradFilters.data<scalar_t>()
                    );
                }}));
            }} else
"""

    bad_padding = """
            {
                std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl;
            }
"""

    breakout = """
    break;
"""

    bad_filter = """
    default:
        std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl;
"""

    con_else = """
        } else
"""

    final_else = """
        {
            switch(filterSize) {
"""

    last_return = """
}

return {gradInput, gradFilters};
}
"""
    kernels = [3, 5, 7, 15, 31, 63, 127, 255]
    seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]]
    thresh = [32, 32, 64, 128, 256, -1, -1, -1]
    max_mem = [-1, -1, -1, -1, -1, 192, 96, 64]
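    # Per-kernel dispatch limits, one entry per filter size above. As far as can
    # be inferred from the loop below: thresh is the longest sequence still served
    # by the "short" weight-gradient kernels, and max_mem caps the sequence length
    # (presumably for shared-memory reasons) before falling back to the generic
    # two-pass kernels; -1 means the corresponding limit is not applied.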
    with open("lightconv_cuda_backward.cu", "w") as backward:
        backward.write(head)

        for k, t, mem in zip(kernels, thresh, max_mem):
            backward.write(case_k.format(k=k))
            for seq in seqs:
                if (t == -1 or seq <= t) and (mem == -1 or seq < mem):
                    backward.write(sequence_if.format(seq=seq))
                    for p in [k // 2, k - 1]:
                        backward.write(main_block.format(k=k, b_size=seq, p=p))
                        backward.write(weight_grad_short.format(k=k, b_size=seq, p=p))
                    backward.write(bad_padding)
                else:
                    # Sequence too long for the short kernels: emit the generic
                    # two-pass version once (block size 32) and close this case.
                    for p in [k // 2, k - 1]:
                        backward.write(main_block.format(k=k, b_size=32, p=p))
                        backward.write(weight_grad.format(k=k, b_size=32, p=p))
                    backward.write(bad_padding)
                    backward.write(breakout)
                    break
                backward.write(con_else)
            backward.write(bad_filter)
        backward.write(last_return)
if __name__ == "__main__":
    gen_forward()
    gen_backward()
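# Usage note (assumption, not part of the original script): run this file to
# regenerate the two .cu sources before compiling the CUDA extension that
# includes them; the exact build step (e.g. a setup.py based on
# torch.utils.cpp_extension.CUDAExtension) depends on the surrounding project.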