Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from ..version_utils import digit_version | |
from .parrots_wrapper import TORCH_VERSION | |
# True when the installed ``torch.meshgrid`` accepts the ``indexing``
# keyword (added in PyTorch 1.10.0a0). Parrots builds are always treated as
# lacking it, and for them the version string is never parsed.
_torch_version_meshgrid_indexing = (
    False if 'parrots' in TORCH_VERSION
    else digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
def torch_meshgrid(*tensors):
    """Call ``torch.meshgrid`` with ``'ij'`` indexing on any PyTorch version.

    Since PyTorch 1.10.0a0, ``torch.meshgrid`` accepts an ``indexing``
    argument, and high-version PyTorch warns when it is omitted. This wrapper
    passes ``indexing='ij'`` explicitly when the installed version supports
    it (per the module-level ``_torch_version_meshgrid_indexing`` flag).
    Otherwise it calls ``torch.meshgrid`` without the keyword, which keeps
    older versions compatible because their default is already ``'ij'``.

    Args:
        tensors (List[Tensor]): List of scalars or 1 dimensional tensors.

    Returns:
        Sequence[Tensor]: Sequence of meshgrid tensors.
    """
    meshgrid_kwargs = {}
    if _torch_version_meshgrid_indexing:
        # Only pass ``indexing`` when this PyTorch build understands it;
        # earlier versions already default to 'ij' indexing.
        meshgrid_kwargs['indexing'] = 'ij'
    return torch.meshgrid(*tensors, **meshgrid_kwargs)