| | |
#pragma once

#include <functional>
#include <utility>
#include <vector>

#include <nanobind/nanobind.h>

#include "mlx/array.h"
| |
|
| | namespace mx = mlx::core; |
| | namespace nb = nanobind; |
| |
|
| | void tree_visit( |
| | const std::vector<nb::object>& trees, |
| | std::function<void(const std::vector<nb::object>&)> visitor); |
| | void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor); |
| |
|
| | nb::object tree_map( |
| | const std::vector<nb::object>& trees, |
| | std::function<nb::object(const std::vector<nb::object>&)> transform); |
| |
|
| | nb::object tree_map( |
| | nb::object tree, |
| | std::function<nb::object(nb::handle)> transform); |
| |
|
| | void tree_visit_update( |
| | nb::object tree, |
| | std::function<nb::object(nb::handle)> visitor); |
| |
|
| | |
| | |
| | |
| | void tree_fill(nb::object& tree, const std::vector<mx::array>& values); |
| |
|
| | |
| | |
| | |
| | |
| | void tree_replace( |
| | nb::object& tree, |
| | const std::vector<mx::array>& src, |
| | const std::vector<mx::array>& dst); |
| |
|
| | |
| | |
| | |
| | |
| | std::vector<mx::array> tree_flatten(nb::handle tree, bool strict = true); |
| |
|
| | |
| | |
| | |
| | nb::object tree_unflatten( |
| | nb::object tree, |
| | const std::vector<mx::array>& values, |
| | int index = 0); |
| |
|
| | std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure( |
| | nb::object tree, |
| | bool strict = true); |
| |
|
| | nb::object tree_unflatten_from_structure( |
| | nb::object structure, |
| | const std::vector<mx::array>& values, |
| | int index = 0); |
| |
|