|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
from typing import Optional

import torch
|
|
|
|
|
|
|
|
def get_device(number: int, logger: Optional[logging.Logger] = None) -> torch.device:
    """Select the CUDA device with the given index, or fall back to CPU.

    When CUDA is available, the requested GPU is made the current device,
    the caching allocator's unused memory is released, and the peak and
    accumulated memory statistics of that GPU are reset. If a logger is
    supplied, the device name and its memory statistics are logged.

    Args:
        number: The index number of the GPU to use.
        logger: Optional logger for the fallback warning and GPU info.
            When omitted, nothing is logged.

    Returns:
        ``torch.device('cuda:<number>')`` when CUDA is available, otherwise
        ``torch.device('cpu')``.

    Raises:
        ValueError: If CUDA is available but ``number`` is negative or not
            smaller than ``torch.cuda.device_count()``.
    """
    # The CPU fallback happens before index validation, so any index is
    # accepted on a machine without CUDA.
    if not torch.cuda.is_available():
        if logger is not None:
            logger.warning("CUDA is not available. Falling back to CPU.")
        return torch.device('cpu')

    device_count = torch.cuda.device_count()
    if number < 0 or number >= device_count:
        raise ValueError(
            f"GPU number {number} is not valid. "
            f"Available GPU indices range from 0 to {device_count - 1}."
        )

    # Make the requested GPU current first and pass its index explicitly, so
    # the statistics are reset on the selected device rather than on whatever
    # device happened to be current when this function was called.
    torch.cuda.set_device(number)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(number)
    torch.cuda.reset_accumulated_memory_stats(number)

    # All reporting is optional: skip the device queries entirely when there
    # is no logger to receive them.
    if logger is not None:
        device_name = torch.cuda.get_device_name(number)
        logger.info(
            "PyTorch is now configured to use GPU %s: %s", number, device_name
        )

        bytes_per_unit = 1024 ** 2  # divisor used for the "MB" figures below
        total_mem = torch.cuda.get_device_properties(number).total_memory / bytes_per_unit
        mem_allocated = torch.cuda.memory_allocated(number) / bytes_per_unit
        mem_reserved = torch.cuda.memory_reserved(number) / bytes_per_unit
        max_allocated = torch.cuda.max_memory_allocated(number) / bytes_per_unit
        max_reserved = torch.cuda.max_memory_reserved(number) / bytes_per_unit

        logger.info("[GPU %s - %s] Memory Stats:", number, device_name)
        logger.info(" Total Memory : %.2f MB", total_mem)
        logger.info(" Currently Allocated : %.2f MB", mem_allocated)
        logger.info(" Currently Reserved : %.2f MB", mem_reserved)
        logger.info(" Max Allocated : %.2f MB", max_allocated)
        logger.info(" Max Reserved : %.2f MB", max_reserved)

    return torch.device(f'cuda:{number}')
|
|
|
|
|
|
|
|
|
|
|
|