File size: 248 Bytes
c19ca42
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from typing import Optional, Union
import torch


# Device specifiers accepted by get_device (``None`` is also allowed there).
rDevice = Union[torch.device, int]


def get_device(device: Optional[rDevice] = None) -> torch.device:
    """Normalize *device* into a ``torch.device`` instance.

    Args:
        device: A ``torch.device`` or an integer device index. When it is
            ``None``, the value of ``torch.dml.current_device()`` is used
            instead. Whatever value results is handed straight to
            ``torch.device``, which decides how it is interpreted
            (e.g. how a bare integer index maps to a device type).

    Returns:
        The ``torch.device`` built from the given or current device.
    """
    # NOTE(review): ``torch.dml`` is not part of upstream PyTorch; presumably
    # it is supplied by a DirectML-enabled build -- confirm before relying on
    # the ``None`` path in other environments.
    resolved = torch.dml.current_device() if device is None else device
    return torch.device(resolved)