Spaces:
Running
on
A100
Running
on
A100
File size: 433 Bytes
cb92d2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import torch
# Select the compute device and a matching default dtype for the rest of the app.
#
# Device priority (highest first): Apple MPS, CUDA, Intel XPU, CPU.
# MPS is only present on macOS builds running on Apple-silicon (M1/M2/M3) chips.
# The hasattr() guards keep this working on torch builds that lack the
# corresponding backend module entirely.
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()

if mps_available:
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
elif xpu_available:
    device = torch.device("xpu")
else:
    device = torch.device("cpu")

# Half precision is used only on the CUDA/XPU accelerators. MPS keeps float32
# (as before). The CPU fallback also uses float32, because many CPU kernels have
# no float16 implementation, or only a very slow one. Loading a model in float16
# on CPU would either fail or crawl.
torch_dtype = torch.float16 if device.type in ("cuda", "xpu") else torch.float32
|