Safetensors
aredden commited on
Commit
fb7df61
·
1 Parent(s): ac049be

Fix issue where torch.dtype throws error when converting to dtype

Browse files
Files changed (1) hide show
  1. util.py +2 -0
util.py CHANGED
@@ -93,6 +93,8 @@ def parse_device(device: str | torch.device | None) -> torch.device:
93
 
94
 
95
  def into_dtype(dtype: str) -> torch.dtype:
 
 
96
  if dtype == "float16":
97
  return torch.float16
98
  elif dtype == "bfloat16":
 
93
 
94
 
95
  def into_dtype(dtype: str) -> torch.dtype:
96
+ if isinstance(dtype, torch.dtype):
97
+ return dtype
98
  if dtype == "float16":
99
  return torch.float16
100
  elif dtype == "bfloat16":