PawMatchAI / device_manager.py
DawnC's picture
Update device_manager.py
0e5cc70
raw
history blame
2.18 kB
import torch
import os
import logging
import spaces
from functools import wraps
# Logger for this module; basicConfig configures the root logger at INFO
# level (a no-op if the root logger already has handlers) so the
# info/warning messages emitted below are printed.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class DeviceManager:
    """Process-wide singleton that selects and remembers the torch device.

    The device is 'cuda' only when a SPACE_ID environment variable is set
    and CUDA is available; otherwise it is 'cpu'.
    """

    # The single shared instance, created on first construction by __new__.
    _instance = None

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            # Read by __init__ so that device detection happens only once.
            instance._initialized = False
            cls._instance = instance
        return instance

    def __init__(self):
        # __init__ runs on every DeviceManager() call; only the first one
        # (while _initialized is still False) does any work.
        if not self._initialized:
            self._initialized = True
            self.device = self._initialize_device()

    def _initialize_device(self):
        """Detect which device to use and return 'cuda' or 'cpu'."""
        try:
            # A set SPACE_ID plus available CUDA is treated as ZeroGPU.
            # NOTE(review): presumably SPACE_ID is set by the HF Spaces
            # runtime -- confirm.
            zero_gpu = bool(os.environ.get('SPACE_ID')) and torch.cuda.is_available()
        except Exception as e:
            logger.warning(f"Unable to initialize ZeroGPU: {e}")
            zero_gpu = False
        if zero_gpu:
            logger.info("ZeroGPU environment detected")
            return 'cuda'
        logger.info("Using CPU")
        return 'cpu'

    def get_device(self):
        """Return the currently selected device string ('cuda' or 'cpu')."""
        return self.device

    def to_device(self, model_or_tensor):
        """Move an object that exposes ``.to()`` onto the managed device.

        Objects without a ``to`` attribute are returned unchanged. If the
        move raises, ``self.device`` is switched to 'cpu' (affecting every
        user of this singleton) and the move is retried on CPU; an error
        from that retry propagates to the caller.
        """
        try:
            if not hasattr(model_or_tensor, 'to'):
                return model_or_tensor
            return model_or_tensor.to(self.device)
        except Exception as e:
            logger.warning(f"Failed to move to {self.device}, using CPU: {e}")
            self.device = 'cpu'
            return model_or_tensor.to('cpu')
def adaptive_gpu(duration=60):
    """Decorator that runs an async function under ``spaces.GPU`` or on CPU.

    On each call, if ``DeviceManager().get_device()`` is 'cuda', the call is
    routed through ``spaces.GPU(duration=duration)``; otherwise ``func`` is
    awaited directly.

    Args:
        duration: Forwarded unchanged as ``spaces.GPU(duration=...)``.

    Returns:
        A decorator producing an async wrapper with ``func``'s metadata.
    """
    def decorator(func):
        # spaces.GPU-wrapped version of func, built on the first CUDA call and
        # then reused, instead of re-wrapping func on every invocation.
        gpu_func = None

        @wraps(func)
        async def wrapper(*args, **kwargs):
            nonlocal gpu_func
            # Checked per call: DeviceManager.to_device can switch the shared
            # device to 'cpu' at runtime after a failed move.
            if DeviceManager().get_device() == 'cuda':
                if gpu_func is None:
                    gpu_func = spaces.GPU(duration=duration)(func)
                return await gpu_func(*args, **kwargs)
            # CPU environment: run the coroutine function directly.
            return await func(*args, **kwargs)
        return wrapper
    return decorator