# PawMatchAI / device_manager.py
# DawnC's picture
# Update device_manager.py
# 1959409
# raw
# history blame
# 1.84 kB
import torch
import os
import logging
# Import-time side effect: set up the root logger at INFO level.
# (logging.basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the code below.
logger = logging.getLogger(__name__)
class DeviceManager:
    """Process-wide singleton that picks the torch device to run on.

    CUDA is selected only when the ``SPACE_ID`` environment variable is set
    *and* ``torch.cuda.is_available()`` returns true. In every other case,
    including any exception raised while probing, the manager falls back to
    the CPU device and logs the reason as a warning.
    """

    # The single shared instance; created lazily by ``__new__``.
    _instance = None

    def __new__(cls):
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # ``__init__`` runs on every ``DeviceManager()`` call, so this
            # flag lets it skip re-initialisation after the first one.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Select a device the first time the singleton is constructed."""
        if self._initialized:
            return
        self._initialized = True
        self._current_device = None
        self.initialize_device()

    def initialize_device(self):
        """Choose the device and store it in ``self._current_device``.

        Never raises for an unavailable GPU: a missing ``SPACE_ID``, an
        unavailable CUDA runtime, or any exception raised while probing all
        end in a logged warning and a CPU device.
        """
        try:
            if not os.environ.get('SPACE_ID'):
                fallback_reason = "Not in Spaces environment"
            elif not torch.cuda.is_available():
                fallback_reason = "CUDA not available"
            else:
                # Try to initialize the CUDA device.
                self._current_device = torch.device('cuda')
                # Make the CUDA device visible.
                # NOTE(review): this runs after torch.cuda.is_available()
                # has already been called and replaces any value already
                # present in the environment; it may only influence CUDA
                # initialisation that happens later (e.g. child processes).
                # Confirm this is intended.
                os.environ['CUDA_VISIBLE_DEVICES'] = '0'
                logger.info("CUDA device initialized successfully")
                return
        except Exception as e:
            # Best-effort by design: any probing failure means "use the CPU".
            fallback_reason = e
        logger.warning("Using CPU due to: %s", fallback_reason)
        self._current_device = torch.device('cpu')

    def get_optimal_device(self):
        """Return the selected device, running selection first if unset."""
        if self._current_device is None:
            self.initialize_device()
        return self._current_device
def to_device(tensor_or_model, device=None):
    """Move a tensor or model to a device, falling back to CPU on failure.

    Args:
        tensor_or_model: Any object exposing a ``.to(device)`` method.
        device: Target device. When ``None``, the device returned by
            ``DeviceManager().get_optimal_device()`` is used.

    Returns:
        The result of ``tensor_or_model.to(device)``, or of
        ``tensor_or_model.to('cpu')`` if the first move raised.

    Raises:
        Exception: Whatever the CPU fallback raises if it fails as well.
    """
    if device is None:
        device = DeviceManager().get_optimal_device()
    try:
        return tensor_or_model.to(device)
    except Exception as e:
        # Deliberate best-effort: any failure on the requested device is
        # logged and retried on the CPU instead of being propagated.
        logger.warning("Failed to move to %s, using CPU: %s", device, e)
        return tensor_or_model.to('cpu')