# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
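
"""Generate type- and size-specialized CUDA dispatch code for LightConv.

Running this script writes two files, lightconv_cuda_forward.cu and
lightconv_cuda_backward.cu, containing one kernel instantiation per
supported (filter size, sequence-length bucket, padding) combination.
"""
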
def gen_forward():
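    """Write lightconv_cuda_forward.cu.

    The generated file is a chain of `if (sequenceLength <= S)` buckets, each
    containing a `switch (filterSize)`; every case tries the two supported
    paddings (k // 2 and k - 1, i.e. the centered and causal cases) before
    falling back to a warning. Roughly:

        if (sequenceLength <= 32) {
            switch(filterSize) {
                case 3:
                    if (padding_l == 1) { /* launch kernel */ } else
                    if (padding_l == 2) { /* launch kernel */ } else
                    { /* warning */ }
                    break;
                ...
            }
        } else ...
    """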
    kernels = [3, 5, 7, 15, 31, 63, 127, 255]
    seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]]
head = """
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "lightconv_cuda.cuh"
std::vector<at::Tensor> lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l) {
at::DeviceGuard g(input.device());
const auto minibatch = input.size(0);
const auto numFeatures = input.size(1);
const auto sequenceLength = input.size(2);
const auto numHeads = filters.size(0);
const auto filterSize = filters.size(1);
const auto numFiltersInBlock = numFeatures / numHeads;
const dim3 blocks(minibatch, numFeatures);
auto output = at::zeros_like(input);
auto stream = at::cuda::getCurrentCUDAStream();
"""
sequence_if = """
if (sequenceLength <= {seq}) {{
switch(filterSize) {{
"""
case_k = """
case {k}:
"""
main_block = """
if (padding_l == {pad}) {{
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_forward", ([&] {{
lightconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t>
<<<blocks, {b_size}, 0, stream>>>(
input.data<scalar_t>(),
filters.data<scalar_t>(),
minibatch,
sequenceLength,
numFeatures,
numFiltersInBlock,
output.data<scalar_t>());
}}));
}} else
"""
bad_padding = """
{
std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl;
}
break;
"""
bad_filter = """
default:
std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl;
}
"""
con_else = """
} else
"""
final_else = """
{
switch(filterSize) {
"""
final_return = """
}
return {output};
}
"""
with open("lightconv_cuda_forward.cu", "w") as forward:
forward.write(head)
for seq in seqs:
forward.write(sequence_if.format(seq=seq))
for k in kernels:
forward.write(case_k.format(k=k))
for pad in [k // 2, k - 1]:
forward.write(main_block.format(k=k, b_size=seq, pad=pad))
forward.write(bad_padding)
forward.write(bad_filter)
forward.write(con_else)
forward.write(final_else)
for k in kernels:
forward.write(case_k.format(k=k))
for pad in [k // 2, k - 1]:
forward.write(main_block.format(k=k, b_size=seq, pad=pad))
forward.write(bad_padding)
forward.write(bad_filter)
forward.write(final_return)
def gen_backward():
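    """Write lightconv_cuda_backward.cu.

    Nesting is inverted relative to the forward pass: the generated function
    switches on filterSize at the top level, and each case chains
    `if (sequenceLength <= S)` buckets. Short sequences use the `_short`
    two-pass weight-gradient kernels with the block size tied to the sequence
    length; longer sequences fall back to the generic kernels with a fixed
    block size of 32 (see `thresh` and `max_mem` below).
    """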
head = """
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "lightconv_cuda.cuh"
std::vector<at::Tensor> lightconv_cuda_backward(
at::Tensor gradOutput,
int padding_l,
at::Tensor input,
at::Tensor filters) {
// gradWrtInput
const int minibatch = input.size(0);
const int numFeatures = input.size(1);
const int sequenceLength = input.size(2);
const int numHeads = filters.size(0);
const int filterSize = filters.size(1);
const dim3 gradBlocks(minibatch, numFeatures);
const dim3 weightGradFirstpassShortBlocks(minibatch, numHeads);
const dim3 weightGradSecondpassBlocks(numHeads, filterSize);
const int numFiltersInBlock = numFeatures / numHeads;
auto gradInput = at::zeros_like(input);
auto gradFilters = at::zeros_like(filters);
at::DeviceGuard g(input.device());
auto stream = at::cuda::getCurrentCUDAStream();
switch(filterSize) {
"""
sequence_if = """
if (sequenceLength <= {seq}) {{
"""
case_k = """
case {k}:
"""
main_block = """
if (padding_l == {p}) {{
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_backward", ([&] {{
lightconv_grad_wrt_input_kernel<{k}, {b_size}, {p}, scalar_t>
<<<gradBlocks, {b_size}, 0, stream>>>(
gradOutput.data<scalar_t>(),
filters.data<scalar_t>(),
minibatch,
sequenceLength,
numFeatures,
numFiltersInBlock,
gradInput.data<scalar_t>());
"""
weight_grad_short = """
at::Tensor tempSumGradFilters = at::zeros({{minibatch, numHeads, filterSize}}, input.options().dtype(at::kFloat));
lightconv_grad_wrt_weights_firstpass_short_kernel<{k}, {b_size}, {p}, scalar_t>
<<<weightGradFirstpassShortBlocks, {b_size}, 0, stream>>>(
input.data<scalar_t>(),
gradOutput.data<scalar_t>(),
minibatch,
sequenceLength,
numFeatures,
numFiltersInBlock,
numHeads,
tempSumGradFilters.data<float>()
);
lightconv_grad_wrt_weights_secondpass_short_kernel<{k}, {b_size}, scalar_t>
<<<weightGradSecondpassBlocks, {b_size}, 0, stream>>>(
tempSumGradFilters.data<float>(),
minibatch,
numFiltersInBlock,
gradFilters.data<scalar_t>()
);
}}));
}} else
"""
weight_grad = """
at::Tensor tempSumGradFilters = at::zeros({{minibatch, numFeatures, filterSize}}, input.options().dtype(at::kFloat));
lightconv_grad_wrt_weights_firstpass_kernel<{k}, {b_size}, {p}, scalar_t>
<<<gradBlocks, {b_size}, 0, stream>>>(
input.data<scalar_t>(),
gradOutput.data<scalar_t>(),
minibatch,
sequenceLength,
numFeatures,
numFiltersInBlock,
tempSumGradFilters.data<float>()
);
lightconv_grad_wrt_weights_secondpass_kernel<{k}, {b_size}, scalar_t>
<<<weightGradSecondpassBlocks, {b_size}, 0, stream>>>(
tempSumGradFilters.data<float>(),
minibatch,
numFiltersInBlock,
gradFilters.data<scalar_t>()
);
}}));
}} else
"""
bad_padding = """
{
std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl;
}
"""
breakout = """
break;
"""
bad_filter = """
default:
std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl;
"""
con_else = """
} else
"""
final_else = """
{
switch(filterSize) {
"""
last_return = """
}
return {gradInput, gradFilters};
}
"""
    kernels = [3, 5, 7, 15, 31, 63, 127, 255]
    seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]]
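    # Per-filter-size bounds for the short weight-gradient path: sequences up
    # to `thresh[i]` use the `_short` kernels, while `max_mem[i]` instead caps
    # the largest filters, presumably to stay within shared-memory limits
    # (-1 disables the respective bound).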
    thresh = [32, 32, 64, 128, 256, -1, -1, -1]
    max_mem = [-1, -1, -1, -1, -1, 192, 96, 64]
with open("lightconv_cuda_backward.cu", "w") as backward:
backward.write(head)
for (k, t, mem) in zip(kernels, thresh, max_mem):
backward.write(case_k.format(k=k))
for seq in seqs:
if (t == -1 or seq <= t) and (mem == -1 or seq < mem):
backward.write(sequence_if.format(seq=seq))
for p in [k // 2, k - 1]:
backward.write(main_block.format(k=k, b_size=seq, p=p))
backward.write(weight_grad_short.format(k=k, b_size=seq, p=p))
backward.write(bad_padding)
else:
for p in [k // 2, k - 1]:
backward.write(main_block.format(k=k, b_size=32, p=p))
backward.write(weight_grad.format(k=k, b_size=32, p=p))
backward.write(bad_padding)
backward.write(breakout)
break
backward.write(con_else)
backward.write(bad_filter)
backward.write(last_return)
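

# Running this file regenerates both .cu sources in the working directory;
# they are compiled as part of the lightconv CUDA extension build.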
if __name__ == "__main__":
    gen_forward()
    gen_backward()