File size: 1,840 Bytes
818a6a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1959409
 
 
f1418ca
 
1959409
 
 
 
 
 
 
 
658c460
1959409
818a6a6
1959409
658c460
818a6a6
 
1959409
 
046ea23
 
1959409
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import os
import logging

# Configure the root logger at INFO level when this module is imported, then
# create the module-level logger used by the definitions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DeviceManager:
    """Process-wide singleton that selects and caches the torch device.

    The first construction runs device selection once; every later
    ``DeviceManager()`` call returns the same instance without re-running it.
    """

    # The single shared instance, created lazily by ``__new__``.
    _instance = None

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            # ``__init__`` runs on every ``DeviceManager()`` call, so this flag
            # lets it tell the first construction from later ones.
            instance._initialized = False
            cls._instance = instance
        return instance

    def __init__(self):
        """Run device selection the first time the singleton is constructed."""
        if not self._initialized:
            self._initialized = True
            self._current_device = None
            self.initialize_device()

    def initialize_device(self):
        """Select CUDA when running in a Space with CUDA available, else CPU.

        The ``SPACE_ID`` environment variable must be set (non-empty) and
        ``torch.cuda.is_available()`` must be true for CUDA to be chosen.
        Any exception raised during selection is logged as a warning and the
        CPU device is used instead; this method does not propagate it.
        """
        try:
            if not os.environ.get('SPACE_ID'):
                raise RuntimeError("Not in Spaces environment")
            # Try to initialize the CUDA device.
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA not available")

            self._current_device = torch.device('cuda')
            # Mark CUDA device 0 as visible.
            # NOTE(review): this is assigned after the availability check and
            # overwrites any existing value -- confirm that is intended.
            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
            logger.info("CUDA device initialized successfully")
        except Exception as e:
            logger.warning(f"Using CPU due to: {e}")
            self._current_device = torch.device('cpu')

    def get_optimal_device(self):
        """Return the cached device, running device selection first if unset."""
        if self._current_device is not None:
            return self._current_device
        self.initialize_device()
        return self._current_device

def to_device(tensor_or_model, device=None):
    """Move a tensor or model to a device, falling back to CPU on failure.

    Args:
        tensor_or_model: Any object exposing a ``.to(device)`` method.
        device: Target device. When ``None``, the device reported by
            ``DeviceManager().get_optimal_device()`` is used.

    Returns:
        Whatever ``tensor_or_model.to(...)`` returns for the target device,
        or for ``'cpu'`` if the first attempt raised.
    """
    # Only consult the singleton when the caller did not pick a device.
    device = DeviceManager().get_optimal_device() if device is None else device

    try:
        moved = tensor_or_model.to(device)
    except Exception as e:
        # Best effort: log the failure and retry on CPU. An error from the
        # retry itself is not caught here.
        logger.warning(f"Failed to move to {device}, using CPU: {e}")
        moved = tensor_or_model.to('cpu')
    return moved