8b0ae10 570c043 8b0ae10 570c043
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import importlib


def get_device():
    """Return the name of the preferred torch device for this machine.

    CUDA is checked first, then Apple's MPS backend. Because the MPS check
    runs last, an available MPS backend overrides CUDA. When neither
    accelerator is available the result is ``"cpu"``.

    Returns:
        One of ``"cpu"``, ``"cuda"`` or ``"mps"``.
    """
    # torch is imported lazily at call time, presumably so this module can be
    # imported without torch installed.
    torch = importlib.import_module("torch")
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    try:
        if torch.backends.mps.is_available():
            device = "mps"
    except AttributeError:
        # torch builds predating the MPS backend have no ``torch.backends.mps``;
        # treat MPS as unavailable instead of failing. Only this error is
        # caught so KeyboardInterrupt/SystemExit and real bugs still propagate.
        pass
    return device