|
import os
import sys

# Let unsupported MPS ops fall back to the CPU on macOS.
# NOTE: this must run before `import torch`. PyTorch consults the variable
# while its MPS backend is registered at import time, so assigning it after
# the import (as this module previously did) has no effect. If another module
# has already imported torch, it is too late for this process either way.
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (deliberately imported after the env var is set)

# Make the current working directory importable.
now_dir = os.getcwd()
sys.path.append(now_dir)

from .logger.log import get_logger  # noqa: E402

# Module-level logger for GPU/device selection messages.
logger = get_logger("gpu")
|
|
|
|
|
def select_device(min_memory=2047, experimental=False):
    """Pick the torch device to run on.

    Selection order:

    1. CUDA: choose the GPU with the most memory not reserved by this
       process (``total_memory - memory_reserved``). If even that GPU has
       less than ``min_memory`` MB available, fall back to the CPU.
    2. Apple MPS: used only when ``experimental`` is true; otherwise CPU.
    3. No GPU: CPU.

    Args:
        min_memory: Minimum unreserved GPU memory, in MB, required to use
            a CUDA device.
        experimental: Opt in to the MPS backend on Apple hardware.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        selected_gpu = 0
        max_free_memory = -1
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            # Only memory reserved by *this* process's caching allocator is
            # subtracted; usage by other processes is not visible here.
            free_memory = props.total_memory - torch.cuda.memory_reserved(i)
            # Strict comparison keeps the lowest-indexed GPU on ties.
            if max_free_memory < free_memory:
                selected_gpu = i
                max_free_memory = free_memory

        free_memory_mb = max_free_memory / (1024 * 1024)
        if free_memory_mb < min_memory:
            # `logger` is already a logger instance; the previous
            # `logger.get_logger().warning(...)` did not match how every
            # other branch uses it.
            logger.warning(
                f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
            )
            return torch.device("cpu")
        return torch.device(f"cuda:{selected_gpu}")

    # `torch.backends.mps` is missing on older torch builds; treat that the
    # same as "MPS not available" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # Currently MPS is slower than CPU while needing more memory and core
        # utilization, so it is only enabled for experimental use.
        if experimental:
            logger.warning("experimantal: found apple GPU, using MPS.")
            return torch.device("mps")
        logger.info("found Apple GPU, but use CPU.")
        return torch.device("cpu")

    logger.warning("no GPU found, use CPU instead")
    return torch.device("cpu")
|
|