jebastin-nadar glenn-jocher committed on
Commit
a0e1504
1 Parent(s): 4a6dfff

Fix different devices bug when moving model from GPU to CPU (#5110)

Browse files

* fix different devices bug

* extend _apply() instead of to() for a general fix

* Only apply if Detect() is last layer

Co-authored-by: Jebastin Nadar <njebastin10@gmail.com>

* Indent fix

* Add comment to yolo.py

* Add comment to common.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (2) hide show
  1. models/common.py +8 -0
  2. models/yolo.py +9 -0
models/common.py CHANGED
@@ -289,6 +289,14 @@ class AutoShape(nn.Module):
289
  LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
290
  return self
291
 
 
 
 
 
 
 
 
 
292
  @torch.no_grad()
293
  def forward(self, imgs, size=640, augment=False, profile=False):
294
  # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
 
289
  LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
290
  return self
291
 
292
+ def _apply(self, fn):
293
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
294
+ self = super()._apply(fn)
295
+ m = self.model.model[-1] # Detect()
296
+ m.stride = fn(m.stride)
297
+ m.grid = list(map(fn, m.grid))
298
+ return self
299
+
300
  @torch.no_grad()
301
  def forward(self, imgs, size=640, augment=False, profile=False):
302
  # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
models/yolo.py CHANGED
@@ -232,6 +232,15 @@ class Model(nn.Module):
232
  def info(self, verbose=False, img_size=640): # print model information
233
  model_info(self, verbose, img_size)
234
 
 
 
 
 
 
 
 
 
 
235
 
236
  def parse_model(d, ch): # model_dict, input_channels(3)
237
  LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
 
232
  def info(self, verbose=False, img_size=640): # print model information
233
  model_info(self, verbose, img_size)
234
 
235
+ def _apply(self, fn):
236
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
237
+ self = super()._apply(fn)
238
+ m = self.model[-1] # Detect()
239
+ if isinstance(m, Detect):
240
+ m.stride = fn(m.stride)
241
+ m.grid = list(map(fn, m.grid))
242
+ return self
243
+
244
 
245
  def parse_model(d, ch): # model_dict, input_channels(3)
246
  LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))