File size: 299 Bytes
8b0ae10
570c043
 
 
8b0ae10
570c043
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import importlib


def get_device():
    """Pick the torch device string to use on this machine.

    ``torch`` is imported lazily at call time (via ``importlib``), so importing
    this module does not by itself require torch to be installed.

    Returns:
        ``"mps"`` if the MPS backend reports itself available, otherwise
        ``"cuda"`` if CUDA is available, otherwise ``"cpu"``. MPS is probed
        after CUDA, so it wins if both report available.
    """
    torch = importlib.import_module("torch")

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    # Keep the try body to the single probe that can fail on older builds.
    try:
        mps_available = torch.backends.mps.is_available()
    except AttributeError:
        # torch builds without the MPS backend have no ``torch.backends.mps``;
        # treat that as "not available" instead of swallowing every exception.
        mps_available = False

    if mps_available:
        device = "mps"

    return device