glenn-jocher
commited on
Commit
•
5bdb28e
1
Parent(s):
c77a5a8
Default PyTorch Hub to `autocast(False)` (#5926)
Browse files- models/common.py +4 -2
models/common.py
CHANGED
@@ -443,6 +443,7 @@ class AutoShape(nn.Module):
|
|
443 |
multi_label = False # NMS multiple labels per box
|
444 |
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
|
445 |
max_det = 1000 # maximum number of detections per image
|
|
|
446 |
|
447 |
def __init__(self, model):
|
448 |
super().__init__()
|
@@ -476,8 +477,9 @@ class AutoShape(nn.Module):
|
|
476 |
|
477 |
t = [time_sync()]
|
478 |
p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
|
|
|
479 |
if isinstance(imgs, torch.Tensor): # torch
|
480 |
-
with amp.autocast(enabled=
|
481 |
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
|
482 |
|
483 |
# Pre-process
|
@@ -506,7 +508,7 @@ class AutoShape(nn.Module):
|
|
506 |
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
507 |
t.append(time_sync())
|
508 |
|
509 |
-
with amp.autocast(enabled=
|
510 |
# Inference
|
511 |
y = self.model(x, augment, profile) # forward
|
512 |
t.append(time_sync())
|
|
|
443 |
multi_label = False # NMS multiple labels per box
|
444 |
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
|
445 |
max_det = 1000 # maximum number of detections per image
|
446 |
+
amp = False # Automatic Mixed Precision (AMP) inference
|
447 |
|
448 |
def __init__(self, model):
|
449 |
super().__init__()
|
|
|
477 |
|
478 |
t = [time_sync()]
|
479 |
p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
|
480 |
+
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
481 |
if isinstance(imgs, torch.Tensor): # torch
|
482 |
+
with amp.autocast(enabled=autocast):
|
483 |
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
|
484 |
|
485 |
# Pre-process
|
|
|
508 |
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
509 |
t.append(time_sync())
|
510 |
|
511 |
+
with amp.autocast(enabled=autocast):
|
512 |
# Inference
|
513 |
y = self.model(x, augment, profile) # forward
|
514 |
t.append(time_sync())
|