Spaces:
Running
Running
import torch | |
import importlib | |
import subprocess | |
import sys | |
def install_package(package_name):
    """Install *package_name* with pip, using the interpreter running this script.

    The command is passed as an argument list (no shell), and pip is invoked
    through ``sys.executable -m pip`` so the package lands in the same
    environment this process runs in.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package_name]
    subprocess.check_call(pip_command)
def check_device(): | |
# **Check for NVIDIA GPU (CUDA)** | |
if torch.cuda.is_available(): | |
device = torch.device("cuda") # Use NVIDIA GPU | |
backend = "CUDA (NVIDIA)" | |
mixed_precision = True # Use Automatic Mixed Precision (AMP) | |
# **If no NVIDIA GPU, check for AMD GPU (DirectML) only in Windows** | |
else: | |
try: | |
# Only try DirectML if the environment is Windows and DirectML is installed | |
if "win32" in sys.platform: | |
torch_directml = importlib.import_module("torch_directml") | |
if torch_directml.device_count() > 0: | |
device = torch_directml.device() # Use AMD GPU with DirectML | |
backend = "DirectML (AMD)" | |
mixed_precision = False # No AMP for AMD GPU | |
else: | |
raise ImportError # AMD GPU not found | |
else: | |
device = torch.device("cpu") | |
backend = "CPU" | |
mixed_precision = False # No AMP for CPU | |
except ImportError: | |
# If DirectML is not installed or AMD GPU not found | |
device = torch.device("cpu") | |
backend = "CPU" | |
mixed_precision = False # No AMP for CPU | |
# Print the chosen device info | |
print(f"Training is running on: {backend} ({device})") | |
# **Initialize scaler (only for NVIDIA)** | |
if mixed_precision: | |
scaler = torch.amp.GradScaler() | |
else: | |
scaler = None # No scaler needed for AMD/CPU | |
return device, backend, scaler | |
if __name__ == "__main__":
    # Script entry point: probe the hardware once and keep the selected
    # device, backend label and (optional) gradient scaler at module level.
    device, backend, scaler = check_device()