Spaces:
Sleeping
Sleeping
File size: 1,811 Bytes
e4e2851 8e90922 e4e2851 14ee6e4 8e90922 e4e2851 8e90922 e4e2851 f222f88 8e90922 f222f88 8e90922 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from functools import wraps
import torch
import os
import logging
import spaces
# Configure root logging at INFO level at import time, then create the
# logger this module writes to.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class DeviceManager:
    """Process-wide singleton that picks the torch device to use.

    When the ``SPACE_ID`` environment variable is set (presumably meaning the
    code runs on a Hugging Face Space), the device is obtained by calling a
    ``spaces.GPU``-wrapped function that returns ``torch.device('cuda')``.
    If that call raises, or ``SPACE_ID`` is not set, the CPU is used.
    """

    # The one shared instance; created lazily by __new__.
    _instance = None

    def __new__(cls):
        # Every DeviceManager() call returns the same object. There is no
        # lock, so simultaneous first construction from several threads is
        # not protected.
        if cls._instance is None:
            instance = super().__new__(cls)
            # Read by __init__ so the device probe runs only once, even though
            # Python calls __init__ on every DeviceManager() call.
            instance._initialized = False
            cls._instance = instance
        return cls._instance

    def __init__(self):
        """Choose the device on first construction; later calls do nothing."""
        if self._initialized:
            return
        self._initialized = True
        self._current_device = None

        if not os.environ.get('SPACE_ID'):
            # Outside a Space the CPU is used unconditionally; local CUDA
            # availability is not checked.
            self._current_device = torch.device('cpu')
            return

        try:
            # Initialize through the spaces GPU wrapper.
            @spaces.GPU
            def init_gpu():
                return torch.device('cuda')

            self._current_device = init_gpu()
            logger.info("ZeroGPU initialized successfully")
        except Exception as e:
            # Best effort: any failure of the GPU probe degrades to CPU
            # instead of breaking construction.
            logger.warning("Failed to initialize ZeroGPU: %s", e)
            self._current_device = torch.device('cpu')

    def get_optimal_device(self):
        """Return the currently selected ``torch.device``.

        This is the device chosen in ``__init__``, unless code elsewhere has
        since overwritten ``_current_device``.
        """
        return self._current_device
def device_handler(func):
    """Wrap an async function with ``spaces.GPU`` and a one-shot CPU fallback.

    The returned coroutine function awaits ``func``. If that raises a
    ``RuntimeError`` whose message contains "out of memory" or "CUDA", the
    shared ``DeviceManager`` is switched to the CPU and ``func`` is awaited
    once more with the same arguments. Any other exception propagates
    unchanged.

    The switch to CPU is stored on the singleton and is never reverted, so
    later calls also see the CPU device. The retry can only help if ``func``
    reads its device from ``DeviceManager`` — NOTE(review): confirm this for
    each decorated function.

    Args:
        func: An ``async def`` function; the wrapper awaits its result.

    Returns:
        The wrapped coroutine function, with ``func``'s metadata preserved.
    """
    @spaces.GPU
    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except RuntimeError as e:
            message = str(e)
            if "out of memory" in message or "CUDA" in message:
                # Include the triggering error; otherwise it is lost when the
                # retry succeeds.
                logger.warning("ZeroGPU unavailable, falling back to CPU: %s", e)
                device_mgr = DeviceManager()
                device_mgr._current_device = torch.device('cpu')
                # Retry once. If this also fails, the new error propagates
                # with the original one attached as its context.
                return await func(*args, **kwargs)
            # Bare raise re-raises the active exception with its traceback.
            raise
    return wrapper