Spaces:
Runtime error
Runtime error
| /* | |
| Copyright (c) Microsoft Corporation. | |
| Licensed under the MIT License. | |
| */ | |
| /* | |
| CPP Binding for CUDA OP | |
| */ | |
| // CUDA forward declarations | |
| torch::Tensor ngram_repeat_block_cuda_forward( | |
| torch::Tensor tokens, | |
| torch::Tensor lprobs, | |
| int bsz, | |
| int step, | |
| int beam_size, | |
| int no_repeat_ngram_size); | |
| // Input check and call to CUDA OP | |
| // Backward method not required | |
| torch::Tensor ngram_repeat_block_forward( | |
| torch::Tensor tokens, | |
| torch::Tensor lprobs, | |
| int bsz, | |
| int step, | |
| int beam_size, | |
| int no_repeat_ngram_size) { | |
| CHECK_INPUT(tokens); | |
| CHECK_INPUT(lprobs); | |
| assert(bsz > 0); | |
| assert(step >= 0); | |
| assert(beam_size > 0); | |
| assert(no_repeat_ngram_size > 0); | |
| return ngram_repeat_block_cuda_forward( | |
| tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size); | |
| } | |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
| m.def( | |
| "forward", | |
| &ngram_repeat_block_forward, | |
| "No Repeat Ngram Block forward (CUDA)"); | |
| } | |