# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch._C import dim

# Re-export the flatten routine provided by the compiled ``dim`` module so
# callers can import it from here.
tree_flatten = dim.tree_flatten


def tree_map(fn, tree):
    """Apply ``fn`` to each value that ``tree_flatten`` extracts from ``tree``.

    ``tree_flatten(tree)`` yields a pair: the flattened values and a callable.
    Every flattened value is passed through ``fn`` (in the order given by
    ``tree_flatten``), and the results are handed to that callable as a
    generator; whatever the callable returns is returned from here.

    NOTE(review): the callable presumably rebuilds a structure shaped like
    ``tree`` with the mapped values in place -- the exact container handling
    lives in ``functorch._C.dim`` and is not visible from this file.

    Args:
        fn: Callable invoked once per flattened value.
        tree: Object to be flattened by ``tree_flatten``.

    Returns:
        The result of calling the unflatten callable on the mapped values.
    """
    leaves, rebuild = tree_flatten(tree)
    # Keep this lazy: the unflatten callable receives a generator, so ``fn``
    # runs only as the callable consumes it.
    mapped = (fn(leaf) for leaf in leaves)
    return rebuild(mapped)