import torch
import importlib
import subprocess
import sys


def install_package(package_name):
    # Install a package into the current Python environment via pip.
    subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
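
# Example usage (illustrative only, not called anywhere in this script):
# torch_directml ships separately from PyTorch, so on Windows it could be
# installed on demand before check_device() runs, e.g.:
#
#     if "win32" in sys.platform:
#         try:
#             importlib.import_module("torch_directml")
#         except ImportError:
#             install_package("torch_directml")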


def check_device():
    # Check for an NVIDIA GPU (CUDA) first.
    if torch.cuda.is_available():
        device = torch.device("cuda")  # Use NVIDIA GPU
        backend = "CUDA (NVIDIA)"
        mixed_precision = True  # Use Automatic Mixed Precision (AMP)
    else:
        # No NVIDIA GPU: check for an AMD GPU via DirectML, Windows only.
        try:
            # Only try DirectML if the platform is Windows and DirectML is installed.
            if "win32" in sys.platform:
                torch_directml = importlib.import_module("torch_directml")
                if torch_directml.device_count() > 0:
                    device = torch_directml.device()  # Use AMD GPU with DirectML
                    backend = "DirectML (AMD)"
                    mixed_precision = False  # No AMP for AMD GPU
                else:
                    raise ImportError  # DirectML installed but no AMD GPU found
            else:
                device = torch.device("cpu")
                backend = "CPU"
                mixed_precision = False  # No AMP for CPU
        except ImportError:
            # DirectML is not installed or no AMD GPU was found.
            device = torch.device("cpu")
            backend = "CPU"
            mixed_precision = False  # No AMP for CPU

    # Print the chosen device info.
    print(f"Training is running on: {backend} ({device})")

    # Initialize the gradient scaler (only needed for NVIDIA/CUDA).
    if mixed_precision:
        scaler = torch.amp.GradScaler()
    else:
        scaler = None  # No scaler needed for AMD/CPU

    return device, backend, scaler
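

# ---------------------------------------------------------------------------
# Illustrative sketch only: how the (device, scaler) pair returned by
# check_device() could drive a single training step. The model, optimizer,
# loss_fn, and batch arguments are hypothetical placeholders, not objects
# defined anywhere in this file.
# ---------------------------------------------------------------------------
def train_step(model, optimizer, loss_fn, batch, device, scaler):
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    if scaler is not None:
        # CUDA path: forward pass under autocast; GradScaler rescales the
        # gradients before the optimizer step.
        with torch.autocast(device_type="cuda"):
            loss = loss_fn(model(inputs), targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        # DirectML / CPU path: plain full-precision training step.
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
    return loss.item()
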
if __name__ == "__main__":
    device, backend, scaler = check_device()