imyhxy commited on
Commit
4e841b9
1 Parent(s): 9708cf5

Reuse `de_parallel()` rather than `is_parallel()` (#6354)

Browse files
Files changed (2) hide show
  1. utils/loss.py +2 -2
  2. utils/torch_utils.py +2 -2
utils/loss.py CHANGED
@@ -7,7 +7,7 @@ import torch
7
  import torch.nn as nn
8
 
9
  from utils.metrics import bbox_iou
10
- from utils.torch_utils import is_parallel
11
 
12
 
13
  def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
@@ -107,7 +107,7 @@ class ComputeLoss:
107
  if g > 0:
108
  BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
109
 
110
- det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
111
  self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
112
  self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
113
  self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
 
7
  import torch.nn as nn
8
 
9
  from utils.metrics import bbox_iou
10
+ from utils.torch_utils import de_parallel
11
 
12
 
13
  def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
 
107
  if g > 0:
108
  BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
109
 
110
+ det = de_parallel(model).model[-1] # Detect() module
111
  self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
112
  self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
113
  self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
utils/torch_utils.py CHANGED
@@ -295,7 +295,7 @@ class ModelEMA:
295
 
296
  def __init__(self, model, decay=0.9999, updates=0):
297
  # Create EMA
298
- self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
299
  # if next(model.parameters()).device.type != 'cpu':
300
  # self.ema.half() # FP16 EMA
301
  self.updates = updates # number of EMA updates
@@ -309,7 +309,7 @@ class ModelEMA:
309
  self.updates += 1
310
  d = self.decay(self.updates)
311
 
312
- msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
313
  for k, v in self.ema.state_dict().items():
314
  if v.dtype.is_floating_point:
315
  v *= d
 
295
 
296
  def __init__(self, model, decay=0.9999, updates=0):
297
  # Create EMA
298
+ self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
299
  # if next(model.parameters()).device.type != 'cpu':
300
  # self.ema.half() # FP16 EMA
301
  self.updates = updates # number of EMA updates
 
309
  self.updates += 1
310
  d = self.decay(self.updates)
311
 
312
+ msd = de_parallel(model).state_dict() # model state_dict
313
  for k, v in self.ema.state_dict().items():
314
  if v.dtype.is_floating_point:
315
  v *= d