#include <torch/extension.h>

// Standard PyTorch-extension input checks (CHECK_INPUT is used by the
// C++ wrappers below).
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// CUDA forward declarations
void adam_upd_cuda(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    int step, float beta1, float beta2, float lr, float eps);

void masked_adam_upd_cuda(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    int step, float beta1, float beta2, float lr, float eps);

void adam_upd_with_perlr_cuda(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    torch::Tensor perlr,
    int step, float beta1, float beta2, float lr, float eps);
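
// The kernel bodies live in a companion .cu file that is not part of this
// binding unit. What follows is a minimal illustrative sketch (an assumption,
// not the project's actual kernel) of what adam_upd_cuda plausibly dispatches
// to: the standard element-wise Adam update, with the bias correction
//   step_size = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
// folded into the step size on the host. The names adam_upd_kernel, N, and
// step_size are introduced here for illustration only.

__global__ void adam_upd_kernel(
    float* __restrict__ param,
    const float* __restrict__ grad,
    float* __restrict__ exp_avg,
    float* __restrict__ exp_avg_sq,
    const size_t N, const float step_size,
    const float beta1, const float beta2, const float eps) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < N) {
    // m <- b1*m + (1-b1)*g ;  v <- b2*v + (1-b2)*g^2
    exp_avg[i]    = beta1 * exp_avg[i]    + (1.0f - beta1) * grad[i];
    exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0f - beta2) * grad[i] * grad[i];
    // p <- p - step_size * m / (sqrt(v) + eps)
    param[i] -= step_size * exp_avg[i] / (sqrtf(exp_avg_sq[i]) + eps);
  }
}

void adam_upd_cuda(
    torch::Tensor param, torch::Tensor grad,
    torch::Tensor exp_avg, torch::Tensor exp_avg_sq,
    int step, float beta1, float beta2, float lr, float eps) {
  const size_t N = param.numel();
  const int threads = 256;
  const int blocks = (int)((N + threads - 1) / threads);
  // Fold Adam's bias correction for step t into a single scalar step size.
  const float step_size = lr * sqrtf(1.0f - powf(beta2, (float)step))
                             / (1.0f - powf(beta1, (float)step));
  adam_upd_kernel<<<blocks, threads>>>(
      param.data_ptr<float>(), grad.data_ptr<float>(),
      exp_avg.data_ptr<float>(), exp_avg_sq.data_ptr<float>(),
      N, step_size, beta1, beta2, eps);
}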
// C++ interface
void adam_upd(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    int step, float beta1, float beta2, float lr, float eps) {
  CHECK_INPUT(param);
  CHECK_INPUT(grad);
  CHECK_INPUT(exp_avg);
  CHECK_INPUT(exp_avg_sq);
  adam_upd_cuda(param, grad, exp_avg, exp_avg_sq,
                step, beta1, beta2, lr, eps);
}
void masked_adam_upd(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    int step, float beta1, float beta2, float lr, float eps) {
  CHECK_INPUT(param);
  CHECK_INPUT(grad);
  CHECK_INPUT(exp_avg);
  CHECK_INPUT(exp_avg_sq);
  masked_adam_upd_cuda(param, grad, exp_avg, exp_avg_sq,
                       step, beta1, beta2, lr, eps);
}
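
// Per its pybind docstring below, the masked variant skips entries whose
// gradient is exactly zero, leaving both the parameter and its Adam moments
// untouched. A hedged sketch of the per-element guard (assumed; the launcher
// would mirror adam_upd_cuda above):

__global__ void masked_adam_upd_kernel(
    float* __restrict__ param,
    const float* __restrict__ grad,
    float* __restrict__ exp_avg,
    float* __restrict__ exp_avg_sq,
    const size_t N, const float step_size,
    const float beta1, const float beta2, const float eps) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  // Only update where the gradient is nonzero, so untouched entries keep
  // their value and optimizer state.
  if (i < N && grad[i] != 0.0f) {
    exp_avg[i]    = beta1 * exp_avg[i]    + (1.0f - beta1) * grad[i];
    exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0f - beta2) * grad[i] * grad[i];
    param[i] -= step_size * exp_avg[i] / (sqrtf(exp_avg_sq[i]) + eps);
  }
}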
void adam_upd_with_perlr(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    torch::Tensor perlr,
    int step, float beta1, float beta2, float lr, float eps) {
  CHECK_INPUT(param);
  CHECK_INPUT(grad);
  CHECK_INPUT(exp_avg);
  CHECK_INPUT(exp_avg_sq);
  adam_upd_with_perlr_cuda(param, grad, exp_avg, exp_avg_sq, perlr,
                           step, beta1, beta2, lr, eps);
}
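
// The per-lr variant additionally takes a tensor of per-voxel learning-rate
// factors, one per parameter entry. A sketch (assumed): identical to the
// masked kernel except the step is scaled element-wise by perlr[i].

__global__ void adam_upd_with_perlr_kernel(
    float* __restrict__ param,
    const float* __restrict__ grad,
    float* __restrict__ exp_avg,
    float* __restrict__ exp_avg_sq,
    const float* __restrict__ perlr,
    const size_t N, const float step_size,
    const float beta1, const float beta2, const float eps) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < N && grad[i] != 0.0f) {
    exp_avg[i]    = beta1 * exp_avg[i]    + (1.0f - beta1) * grad[i];
    exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0f - beta2) * grad[i] * grad[i];
    // Per-voxel learning rate scales the global step size element-wise.
    param[i] -= perlr[i] * step_size * exp_avg[i] / (sqrtf(exp_avg_sq[i]) + eps);
  }
}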
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("adam_upd", &adam_upd,
        "Adam update");
  m.def("masked_adam_upd", &masked_adam_upd,
        "Adam update ignoring zero grad");
  m.def("adam_upd_with_perlr", &adam_upd_with_perlr,
        "Adam update ignoring zero grad with per-voxel lr");
}
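
// Typical usage (inferred from the bindings above, not shown in this file):
// compile this file together with the companion .cu kernels, e.g. via
// torch.utils.cpp_extension.load or a setup.py CUDAExtension, then call the
// exported functions once per training step on flat, contiguous float32 CUDA
// tensors in place of torch.optim.Adam's step().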