nbansal commited on
Commit
2c33aa3
1 Parent(s): 47f92c7

Minor modifications to how device is selected

Browse files
Files changed (2) hide show
  1. tests.py +5 -2
  2. utils.py +1 -1
tests.py CHANGED
@@ -31,10 +31,13 @@ class TestUtils(unittest.TestCase):
31
  self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])
32
 
33
  # Test list input with duplicate elements
34
- self.assertEqual(get_gpu([0, 0, "gpu"]), [0] if gpu_available else ["cpu", "cpu", "cpu"])
35
 
36
  # Test list input with duplicate elements of different types
37
- self.assertEqual(get_gpu([True, 0, "gpu"]), [0] if gpu_available else ["cpu", "cpu", "cpu"])
 
 
 
38
 
39
  # Test list input with all integers
40
  self.assertEqual(get_gpu(list(range(gpu_count))),
 
31
  self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])
32
 
33
  # Test list input with duplicate elements
34
+ self.assertEqual(get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
35
 
36
  # Test list input with duplicate elements of different types
37
+ self.assertEqual(get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
38
+
39
+ # Test list input but only one element
40
+ self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")
41
 
42
  # Test list input with all integers
43
  self.assertEqual(get_gpu(list(range(gpu_count))),
utils.py CHANGED
@@ -73,7 +73,7 @@ def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
73
  result.append(device)
74
  else:
75
  result.append(device)
76
- return result
77
  else:
78
  return _get_single_device(gpu)
79
 
 
73
  result.append(device)
74
  else:
75
  result.append(device)
76
+ return result[0] if len(result) == 1 else result
77
  else:
78
  return _get_single_device(gpu)
79