File size: 1,808 Bytes
c689089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import importlib
import subprocess
import sys

def install_package(package_name):
    """Install *package_name* into the current interpreter via pip.

    Uses ``sys.executable -m pip`` so the package lands in the same
    environment this script is running in. Raises
    ``subprocess.CalledProcessError`` if pip exits non-zero.
    """
    cmd = [sys.executable, "-m", "pip", "install", package_name]
    subprocess.check_call(cmd)

def check_device():
    """Pick the best available compute device for training.

    Preference order: NVIDIA GPU (CUDA) -> AMD GPU via DirectML
    (Windows only, requires the ``torch_directml`` package) -> CPU.

    Returns:
        tuple: ``(device, backend, scaler)`` where
            device  -- a ``torch.device`` (or a torch_directml device) to
                       place tensors/models on,
            backend -- human-readable backend label,
            scaler  -- ``torch.amp.GradScaler`` when CUDA is selected,
                       otherwise ``None`` (AMP is only enabled on NVIDIA).
    """
    if torch.cuda.is_available():
        # NVIDIA GPU takes priority; AMP is supported here.
        device = torch.device("cuda")
        backend = "CUDA (NVIDIA)"
        mixed_precision = True
    else:
        # Default to CPU, then upgrade to DirectML if possible.
        device = torch.device("cpu")
        backend = "CPU"
        mixed_precision = False  # no AMP on DirectML/CPU

        # DirectML (AMD) is only attempted on Windows. Note the original
        # code raised ImportError as control flow to fall back to CPU;
        # here the fallback is simply "leave the CPU defaults in place".
        if sys.platform == "win32":
            try:
                torch_directml = importlib.import_module("torch_directml")
                if torch_directml.device_count() > 0:
                    device = torch_directml.device()
                    backend = "DirectML (AMD)"
            except ImportError:
                pass  # torch_directml not installed -> stay on CPU

    # Report the chosen device info.
    print(f"Training is running on: {backend} ({device})")

    # Gradient scaler is only needed for CUDA mixed-precision training.
    scaler = torch.amp.GradScaler() if mixed_precision else None

    return device, backend, scaler

if __name__ == "__main__":
    # Script entry point: select and report the training device/backend.
    device, backend, scaler = check_device()