glenn-jocher commited on
Commit
5bdb28e
1 Parent(s): c77a5a8

Default PyTorch Hub to `autocast(False)` (#5926)

Browse files
Files changed (1) hide show
  1. models/common.py +4 -2
models/common.py CHANGED
@@ -443,6 +443,7 @@ class AutoShape(nn.Module):
443
  multi_label = False # NMS multiple labels per box
444
  classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
445
  max_det = 1000 # maximum number of detections per image
 
446
 
447
  def __init__(self, model):
448
  super().__init__()
@@ -476,8 +477,9 @@ class AutoShape(nn.Module):
476
 
477
  t = [time_sync()]
478
  p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
 
479
  if isinstance(imgs, torch.Tensor): # torch
480
- with amp.autocast(enabled=p.device.type != 'cpu'):
481
  return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
482
 
483
  # Pre-process
@@ -506,7 +508,7 @@ class AutoShape(nn.Module):
506
  x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
507
  t.append(time_sync())
508
 
509
- with amp.autocast(enabled=p.device.type != 'cpu'):
510
  # Inference
511
  y = self.model(x, augment, profile) # forward
512
  t.append(time_sync())
 
443
  multi_label = False # NMS multiple labels per box
444
  classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
445
  max_det = 1000 # maximum number of detections per image
446
+ amp = False # Automatic Mixed Precision (AMP) inference
447
 
448
  def __init__(self, model):
449
  super().__init__()
 
477
 
478
  t = [time_sync()]
479
  p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
480
+ autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
481
  if isinstance(imgs, torch.Tensor): # torch
482
+ with amp.autocast(enabled=autocast):
483
  return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
484
 
485
  # Pre-process
 
508
  x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
509
  t.append(time_sync())
510
 
511
+ with amp.autocast(enabled=autocast):
512
  # Inference
513
  y = self.model(x, augment, profile) # forward
514
  t.append(time_sync())