Fix issue where torch.dtype throws error when converting to dtype
Browse files
util.py
CHANGED
@@ -93,6 +93,8 @@ def parse_device(device: str | torch.device | None) -> torch.device:
|
|
93 |
|
94 |
|
95 |
def into_dtype(dtype: str) -> torch.dtype:
|
|
|
|
|
96 |
if dtype == "float16":
|
97 |
return torch.float16
|
98 |
elif dtype == "bfloat16":
|
|
|
93 |
|
94 |
|
95 |
def into_dtype(dtype: str) -> torch.dtype:
|
96 |
+
if isinstance(dtype, torch.dtype):
|
97 |
+
return dtype
|
98 |
if dtype == "float16":
|
99 |
return torch.float16
|
100 |
elif dtype == "bfloat16":
|