/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */
| namespace { | |
| void alignmentTrainCUDA( | |
| const torch::Tensor& p_choose, | |
| torch::Tensor& alpha, | |
| float eps) { | |
| CHECK_INPUT(p_choose); | |
| CHECK_INPUT(alpha); | |
| alignmentTrainCUDAWrapper(p_choose, alpha, eps); | |
| } | |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
| m.def( | |
| "alignment_train_cuda", | |
| &alignmentTrainCUDA, | |
| "expected_alignment_from_p_choose (CUDA)"); | |
| } | |
| } // namespace | |