DawnC commited on
Commit
046ea23
·
1 Parent(s): fbb22a3

Update device_manager.py

Browse files
Files changed (1) hide show
  1. device_manager.py +33 -12
device_manager.py CHANGED
@@ -25,25 +25,46 @@ class DeviceManager:
25
 
26
  def check_zero_gpu_availability(self):
27
  try:
28
- if 'SPACE_ID' in os.environ:
29
- api = HfApi()
30
- space_info = api.get_space_runtime(os.environ['SPACE_ID'])
31
- if hasattr(space_info, 'hardware') and space_info.hardware.get('zerogpu', False):
32
- return True
 
 
 
 
 
33
  except Exception as e:
34
  logger.warning(f"Error checking ZeroGPU availability: {e}")
35
- return False
36
 
37
  def get_optimal_device(self):
38
  if self._current_device is None:
39
  if self.check_zero_gpu_availability():
40
  try:
41
- self._current_device = torch.device('cuda')
42
- logger.info("Using ZeroGPU")
43
- except Exception:
 
 
 
 
 
44
  self._current_device = torch.device('cpu')
45
- logger.info("Failed to use ZeroGPU, falling back to CPU")
46
  else:
47
  self._current_device = torch.device('cpu')
48
- logger.info("Using CPU")
49
- return self._current_device
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def check_zero_gpu_availability(self):
27
  try:
28
+ # 檢查 Hugging Face Space 環境變數
29
+ if not os.environ.get('SPACE_ID'):
30
+ return False
31
+
32
+ # 檢查是否在 Spaces 環境中並且啟用了 ZeroGPU
33
+ if os.environ.get('ZERO_GPU_AVAILABLE') == '1':
34
+ return True
35
+
36
+ return False
37
+
38
  except Exception as e:
39
  logger.warning(f"Error checking ZeroGPU availability: {e}")
40
+ return False
41
 
42
  def get_optimal_device(self):
43
  if self._current_device is None:
44
  if self.check_zero_gpu_availability():
45
  try:
46
+ # 確保 CUDA 可用
47
+ if torch.cuda.is_available():
48
+ self._current_device = torch.device('cuda')
49
+ logger.info("Using ZeroGPU")
50
+ else:
51
+ raise RuntimeError("CUDA not available")
52
+ except Exception as e:
53
+ logger.warning(f"Failed to initialize ZeroGPU: {e}")
54
  self._current_device = torch.device('cpu')
55
+ logger.info("Fallback to CPU due to GPU initialization failure")
56
  else:
57
  self._current_device = torch.device('cpu')
58
+ logger.info("Using CPU (ZeroGPU not available)")
59
+ return self._current_device
60
+
61
+ def move_to_device(self, tensor_or_model):
62
+ device = self.get_optimal_device()
63
+ try:
64
+ if hasattr(tensor_or_model, 'to'):
65
+ return tensor_or_model.to(device)
66
+ except Exception:
67
+ self._current_device = torch.device('cpu')
68
+ if hasattr(tensor_or_model, 'to'):
69
+ return tensor_or_model.to('cpu')
70
+ return tensor_or_model