namespace at { namespace native { | |
void _fused_adam_cuda_impl_( | |
at::TensorList params, | |
at::TensorList grads, | |
at::TensorList exp_avgs, | |
at::TensorList exp_avg_sqs, | |
at::TensorList state_steps, | |
const double lr, | |
const double beta1, | |
const double beta2, | |
const double weight_decay, | |
const double eps, | |
const bool amsgrad, | |
const bool maximize, | |
const c10::optional<at::Tensor>& grad_scale, | |
const c10::optional<at::Tensor>& found_inf | |
); | |
} } // namespace at::native | |