Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Union | |
import numpy as np | |
import torch | |
TypeSpec = Union[str, np.dtype, torch.dtype] | |
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { | |
np.dtype("bool"): torch.bool, | |
np.dtype("uint8"): torch.uint8, | |
np.dtype("int8"): torch.int8, | |
np.dtype("int16"): torch.int16, | |
np.dtype("int32"): torch.int32, | |
np.dtype("int64"): torch.int64, | |
np.dtype("float16"): torch.float16, | |
np.dtype("float32"): torch.float32, | |
np.dtype("float64"): torch.float64, | |
np.dtype("complex64"): torch.complex64, | |
np.dtype("complex128"): torch.complex128, | |
} | |
def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Convert a dtype specification to the equivalent ``torch.dtype``.

    Args:
        dtype: A ``torch.dtype`` (returned unchanged), an ``np.dtype``, a dtype
            string understood by ``np.dtype`` (e.g. ``"float32"``), or a numpy
            scalar type (e.g. ``np.float32``).

    Returns:
        The matching ``torch.dtype`` from ``_NUMPY_TO_TORCH_DTYPE``.

    Raises:
        TypeError: If ``dtype`` is none of the accepted kinds, or is a string
            that numpy cannot parse as a dtype.
        KeyError: If ``dtype`` is a valid numpy dtype that has no entry in
            ``_NUMPY_TO_TORCH_DTYPE``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    # Normalize strings and numpy scalar types (np.float32, ...) to np.dtype.
    # Other objects are deliberately NOT passed to np.dtype(): it accepts far
    # too much (e.g. np.dtype(None) is float64), which would hide caller bugs.
    if isinstance(dtype, str) or (
        isinstance(dtype, type) and issubclass(dtype, np.generic)
    ):
        dtype = np.dtype(dtype)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(dtype, np.dtype):
        raise TypeError(f"Expected an instance of numpy dtype, got {type(dtype)}")
    try:
        return _NUMPY_TO_TORCH_DTYPE[dtype]
    except KeyError:
        raise KeyError(
            f"Unsupported numpy dtype for conversion to torch: {dtype}"
        ) from None