glenn-jocher commited on
Commit
a1c8406
1 Parent(s): 2377e5f

EMA and non_blocking=True

Browse files
Files changed (2) hide show
  1. test.py +1 -1
  2. train.py +2 -2
test.py CHANGED
@@ -69,7 +69,7 @@ def test(data,
69
  loss = torch.zeros(3, device=device)
70
  jdict, stats, ap, ap_class = [], [], [], []
71
  for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
72
- img = img.to(device)
73
  img = img.half() if half else img.float() # uint8 to fp16/32
74
  img /= 255.0 # 0 - 255 to 0.0 - 1.0
75
  targets = targets.to(device)
 
69
  loss = torch.zeros(3, device=device)
70
  jdict, stats, ap, ap_class = [], [], [], []
71
  for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
72
+ img = img.to(device, non_blocking=True)
73
  img = img.half() if half else img.float() # uint8 to fp16/32
74
  img /= 255.0 # 0 - 255 to 0.0 - 1.0
75
  targets = targets.to(device)
train.py CHANGED
@@ -193,7 +193,7 @@ def train(hyp):
193
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
194
 
195
  # Exponential moving average
196
- ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate)
197
 
198
  # Start training
199
  t0 = time.time()
@@ -223,7 +223,7 @@ def train(hyp):
223
  pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
224
  for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
225
  ni = i + nb * epoch # number integrated batches (since train start)
226
- imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
227
 
228
  # Warmup
229
  if ni <= nw:
 
193
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
194
 
195
  # Exponential moving average
196
+ ema = torch_utils.ModelEMA(model)
197
 
198
  # Start training
199
  t0 = time.time()
 
223
  pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
224
  for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
225
  ni = i + nb * epoch # number integrated batches (since train start)
226
+ imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
227
 
228
  # Warmup
229
  if ni <= nw: