Reuse `de_parallel()` rather than `is_parallel()` (#6354)
Browse files- utils/loss.py +2 -2
- utils/torch_utils.py +2 -2
utils/loss.py
CHANGED
@@ -7,7 +7,7 @@ import torch
|
|
7 |
import torch.nn as nn
|
8 |
|
9 |
from utils.metrics import bbox_iou
|
10 |
-
from utils.torch_utils import
|
11 |
|
12 |
|
13 |
def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
|
@@ -107,7 +107,7 @@ class ComputeLoss:
|
|
107 |
if g > 0:
|
108 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
109 |
|
110 |
-
det =
|
111 |
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
|
112 |
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
|
113 |
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
|
|
|
7 |
import torch.nn as nn
|
8 |
|
9 |
from utils.metrics import bbox_iou
|
10 |
+
from utils.torch_utils import de_parallel
|
11 |
|
12 |
|
13 |
def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
|
|
|
107 |
if g > 0:
|
108 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
109 |
|
110 |
+
det = de_parallel(model).model[-1] # Detect() module
|
111 |
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
|
112 |
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
|
113 |
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
|
utils/torch_utils.py
CHANGED
@@ -295,7 +295,7 @@ class ModelEMA:
|
|
295 |
|
296 |
def __init__(self, model, decay=0.9999, updates=0):
|
297 |
# Create EMA
|
298 |
-
self.ema = deepcopy(
|
299 |
# if next(model.parameters()).device.type != 'cpu':
|
300 |
# self.ema.half() # FP16 EMA
|
301 |
self.updates = updates # number of EMA updates
|
@@ -309,7 +309,7 @@ class ModelEMA:
|
|
309 |
self.updates += 1
|
310 |
d = self.decay(self.updates)
|
311 |
|
312 |
-
msd =
|
313 |
for k, v in self.ema.state_dict().items():
|
314 |
if v.dtype.is_floating_point:
|
315 |
v *= d
|
|
|
295 |
|
296 |
def __init__(self, model, decay=0.9999, updates=0):
|
297 |
# Create EMA
|
298 |
+
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
299 |
# if next(model.parameters()).device.type != 'cpu':
|
300 |
# self.ema.half() # FP16 EMA
|
301 |
self.updates = updates # number of EMA updates
|
|
|
309 |
self.updates += 1
|
310 |
d = self.decay(self.updates)
|
311 |
|
312 |
+
msd = de_parallel(model).state_dict() # model state_dict
|
313 |
for k, v in self.ema.state_dict().items():
|
314 |
if v.dtype.is_floating_point:
|
315 |
v *= d
|