Commit
•
a0e1504
1
Parent(s):
4a6dfff
Fix different devices bug when moving model from GPU to CPU (#5110)
Browse files* fix different devices bug
* extend _apply() instead of to() for a general fix
* Only apply if Detect() is last layer
Co-authored-by: Jebastin Nadar <njebastin10@gmail.com>
* Indent fix
* Add comment to yolo.py
* Add comment to common.py
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
- models/common.py +8 -0
- models/yolo.py +9 -0
models/common.py
CHANGED
@@ -289,6 +289,14 @@ class AutoShape(nn.Module):
|
|
289 |
LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
|
290 |
return self
|
291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
@torch.no_grad()
|
293 |
def forward(self, imgs, size=640, augment=False, profile=False):
|
294 |
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
|
|
289 |
LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
|
290 |
return self
|
291 |
|
292 |
+
def _apply(self, fn):
|
293 |
+
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
294 |
+
self = super()._apply(fn)
|
295 |
+
m = self.model.model[-1] # Detect()
|
296 |
+
m.stride = fn(m.stride)
|
297 |
+
m.grid = list(map(fn, m.grid))
|
298 |
+
return self
|
299 |
+
|
300 |
@torch.no_grad()
|
301 |
def forward(self, imgs, size=640, augment=False, profile=False):
|
302 |
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
models/yolo.py
CHANGED
@@ -232,6 +232,15 @@ class Model(nn.Module):
|
|
232 |
def info(self, verbose=False, img_size=640): # print model information
|
233 |
model_info(self, verbose, img_size)
|
234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
|
236 |
def parse_model(d, ch): # model_dict, input_channels(3)
|
237 |
LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
|
|
|
232 |
def info(self, verbose=False, img_size=640): # print model information
|
233 |
model_info(self, verbose, img_size)
|
234 |
|
235 |
+
def _apply(self, fn):
|
236 |
+
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
237 |
+
self = super()._apply(fn)
|
238 |
+
m = self.model[-1] # Detect()
|
239 |
+
if isinstance(m, Detect):
|
240 |
+
m.stride = fn(m.stride)
|
241 |
+
m.grid = list(map(fn, m.grid))
|
242 |
+
return self
|
243 |
+
|
244 |
|
245 |
def parse_model(d, ch): # model_dict, input_channels(3)
|
246 |
LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
|