Nanobit glenn-jocher committed on
Commit 035ac82
1 Parent(s): 69ea70c

Fix torch multi-GPU --device error (#1701)


* Fix torch GPU error

* Update torch_utils.py

single-line device =

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (1)
  1. utils/torch_utils.py +3 -2
utils/torch_utils.py CHANGED
@@ -75,13 +75,14 @@ def time_synchronized():
     return time.time()


-def profile(x, ops, n=100, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
+def profile(x, ops, n=100, device=None):
    # profile a pytorch module or list of modules. Example usage:
    #     x = torch.randn(16, 3, 640, 640)  # input
    #     m1 = lambda x: x * torch.sigmoid(x)
    #     m2 = nn.SiLU()
    #     profile(x, [m1, m2], n=100)  # profile speed over 100 iterations
-
+
+    device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    x = x.to(device)
    x.requires_grad = True
    print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
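For reference, a minimal usage sketch of the patched helper under its new lazy default. The input shape and the two modules mirror the example already given in the function's own comment; the explicit multi-GPU call at the end is an illustrative assumption, not something taken from the commit.

import torch
import torch.nn as nn
from utils.torch_utils import profile  # assumes the YOLOv5 repo root is on PYTHONPATH

x = torch.randn(16, 3, 640, 640)      # example input batch (from the docstring comment)
m1 = lambda x: x * torch.sigmoid(x)   # hand-written SiLU
m2 = nn.SiLU()                        # built-in SiLU

# device omitted: the patched default resolves at call time to cuda:0 if available, else cpu
profile(x, [m1, m2], n=100)

# an explicit device can still be passed on a multi-GPU machine (illustrative)
profile(x, [m1, m2], n=100, device=torch.device('cuda:1'))

Because the default is now resolved inside the function body rather than at import time, importing utils.torch_utils no longer pins the profiler to cuda:0 before a --device selection has taken effect.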