# NOTE: file-viewer page residue captured with this file (not part of the code):
#   radames's picture | new base | cb92d2b | raw | history blame | 433 Bytes
import torch
# Pick the torch device and the tensor dtype once, at import time.
#
# Result:
#   device      -- torch.device for "mps", "cuda", "xpu" or "cpu"
#   torch_dtype -- torch.float32 on MPS, torch.float16 everywhere else

# Check if MPS is available (macOS only, M1/M2/M3 chips). The hasattr() guard
# keeps the probe safe on torch builds that do not ship the backend at all.
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
# Same defensive probe for the XPU backend.
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()

if mps_available:
    # MPS wins over every other backend and is the only one run in float32.
    device = torch.device("mps")
    torch_dtype = torch.float32
elif torch.cuda.is_available():
    device = torch.device("cuda")
    torch_dtype = torch.float16
elif xpu_available:
    device = torch.device("xpu")
    torch_dtype = torch.float16
else:
    # NOTE(review): the CPU fallback keeps float16, as in the original code;
    # confirm half precision on CPU is really intended by the callers.
    device = torch.device("cpu")
    torch_dtype = torch.float16