glenn-jocher
commited on
Commit
·
a1c8406
1
Parent(s):
2377e5f
EMA and non_blocking=True
Browse files
test.py
CHANGED
@@ -69,7 +69,7 @@ def test(data,
|
|
69 |
loss = torch.zeros(3, device=device)
|
70 |
jdict, stats, ap, ap_class = [], [], [], []
|
71 |
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
|
72 |
-
img = img.to(device)
|
73 |
img = img.half() if half else img.float() # uint8 to fp16/32
|
74 |
img /= 255.0 # 0 - 255 to 0.0 - 1.0
|
75 |
targets = targets.to(device)
|
|
|
69 |
loss = torch.zeros(3, device=device)
|
70 |
jdict, stats, ap, ap_class = [], [], [], []
|
71 |
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
|
72 |
+
img = img.to(device, non_blocking=True)
|
73 |
img = img.half() if half else img.float() # uint8 to fp16/32
|
74 |
img /= 255.0 # 0 - 255 to 0.0 - 1.0
|
75 |
targets = targets.to(device)
|
train.py
CHANGED
@@ -193,7 +193,7 @@ def train(hyp):
|
|
193 |
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
194 |
|
195 |
# Exponential moving average
|
196 |
-
ema = torch_utils.ModelEMA(model
|
197 |
|
198 |
# Start training
|
199 |
t0 = time.time()
|
@@ -223,7 +223,7 @@ def train(hyp):
|
|
223 |
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
224 |
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
225 |
ni = i + nb * epoch # number integrated batches (since train start)
|
226 |
-
imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
|
227 |
|
228 |
# Warmup
|
229 |
if ni <= nw:
|
|
|
193 |
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
194 |
|
195 |
# Exponential moving average
|
196 |
+
ema = torch_utils.ModelEMA(model)
|
197 |
|
198 |
# Start training
|
199 |
t0 = time.time()
|
|
|
223 |
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
224 |
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
225 |
ni = i + nb * epoch # number integrated batches (since train start)
|
226 |
+
imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
|
227 |
|
228 |
# Warmup
|
229 |
if ni <= nw:
|