Commit
•
035ac82
1
Parent(s):
69ea70c
Fix torch multi-GPU --device error (#1701)
Browse files* Fix torch GPU error
* Update torch_utils.py
single-line device =
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
- utils/torch_utils.py +3 -2
utils/torch_utils.py
CHANGED
@@ -75,13 +75,14 @@ def time_synchronized():
|
|
75 |
return time.time()
|
76 |
|
77 |
|
78 |
-
def profile(x, ops, n=100, device=
|
79 |
# profile a pytorch module or list of modules. Example usage:
|
80 |
# x = torch.randn(16, 3, 640, 640) # input
|
81 |
# m1 = lambda x: x * torch.sigmoid(x)
|
82 |
# m2 = nn.SiLU()
|
83 |
# profile(x, [m1, m2], n=100) # profile speed over 100 iterations
|
84 |
-
|
|
|
85 |
x = x.to(device)
|
86 |
x.requires_grad = True
|
87 |
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
|
|
|
75 |
return time.time()
|
76 |
|
77 |
|
78 |
+
def profile(x, ops, n=100, device=None):
|
79 |
# profile a pytorch module or list of modules. Example usage:
|
80 |
# x = torch.randn(16, 3, 640, 640) # input
|
81 |
# m1 = lambda x: x * torch.sigmoid(x)
|
82 |
# m2 = nn.SiLU()
|
83 |
# profile(x, [m1, m2], n=100) # profile speed over 100 iterations
|
84 |
+
|
85 |
+
device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
86 |
x = x.to(device)
|
87 |
x.requires_grad = True
|
88 |
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
|