Fix issue where torch.dtype throws error when converting to dtype
Browse files
util.py
CHANGED
|
@@ -93,6 +93,8 @@ def parse_device(device: str | torch.device | None) -> torch.device:
|
|
| 93 |
|
| 94 |
|
| 95 |
def into_dtype(dtype: str) -> torch.dtype:
|
|
|
|
|
|
|
| 96 |
if dtype == "float16":
|
| 97 |
return torch.float16
|
| 98 |
elif dtype == "bfloat16":
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
def into_dtype(dtype: str) -> torch.dtype:
|
| 96 |
+
if isinstance(dtype, torch.dtype):
|
| 97 |
+
return dtype
|
| 98 |
if dtype == "float16":
|
| 99 |
return torch.float16
|
| 100 |
elif dtype == "bfloat16":
|