JohanDL's picture
adding checkpoints
bbd0fe5
raw
history blame
No virus
1.11 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Union
import numpy as np
import torch
# A dtype may be given as a string, a NumPy dtype, or a torch dtype.
TypeSpec = Union[str, np.dtype, torch.dtype]

# Lookup table from NumPy dtypes to torch dtypes. Every entry has the same
# name in both libraries, so the table is derived from the shared names.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype(name): getattr(torch, name)
    for name in (
        "bool",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
    )
}
def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Convert a dtype specification to the matching ``torch.dtype``.

    Args:
        dtype: A ``torch.dtype`` (returned unchanged), a ``numpy.dtype``, or a
            string that is first parsed with ``numpy.dtype`` (e.g.
            ``"float32"``).

    Returns:
        The ``torch.dtype`` that ``_NUMPY_TO_TORCH_DTYPE`` associates with the
        given NumPy dtype, or ``dtype`` itself if it already is a torch dtype.

    Raises:
        TypeError: If ``dtype`` is not a ``torch.dtype``, a ``numpy.dtype`` or
            a string. Errors raised by ``numpy.dtype`` while parsing a string
            are propagated unchanged.
        KeyError: If the NumPy dtype has no entry in ``_NUMPY_TO_TORCH_DTYPE``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    # Validate with an explicit raise instead of ``assert``: assertions are
    # removed under ``python -O``, which would let an invalid argument fall
    # through to the table lookup and fail with an unrelated error.
    if not isinstance(dtype, np.dtype):
        raise TypeError(f"Expected an instance of numpy dtype, got {type(dtype)}")
    try:
        return _NUMPY_TO_TORCH_DTYPE[dtype]
    except KeyError:
        # Same exception type as a plain lookup, but with a message that says
        # what is actually wrong.
        raise KeyError(f"No torch dtype is registered for numpy dtype {dtype!r}") from None