glenn-jocher commited on
Commit
b8557f8
1 Parent(s): 3b06225

add stride to datasets.py

Browse files
Files changed (3) hide show
  1. test.py +1 -0
  2. train.py +4 -2
  3. utils/datasets.py +2 -2
test.py CHANGED
@@ -73,6 +73,7 @@ def test(data,
73
  batch_size,
74
  rect=True, # rectangular inference
75
  single_cls=opt.single_cls, # single class mode
 
76
  pad=0.5) # padding
77
  batch_size = min(batch_size, len(dataset))
78
  nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
 
73
  batch_size,
74
  rect=True, # rectangular inference
75
  single_cls=opt.single_cls, # single class mode
76
+ stride=int(max(model.stride)), # model stride
77
  pad=0.5) # padding
78
  batch_size = min(batch_size, len(dataset))
79
  nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
train.py CHANGED
@@ -160,7 +160,8 @@ def train(hyp):
160
  hyp=hyp, # augmentation hyperparameters
161
  rect=opt.rect, # rectangular training
162
  cache_images=opt.cache_images,
163
- single_cls=opt.single_cls)
 
164
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
165
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
166
 
@@ -179,7 +180,8 @@ def train(hyp):
179
  hyp=hyp,
180
  rect=True,
181
  cache_images=opt.cache_images,
182
- single_cls=opt.single_cls),
 
183
  batch_size=batch_size,
184
  num_workers=nw,
185
  pin_memory=True,
 
160
  hyp=hyp, # augmentation hyperparameters
161
  rect=opt.rect, # rectangular training
162
  cache_images=opt.cache_images,
163
+ single_cls=opt.single_cls,
164
+ stride=gs)
165
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
166
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
167
 
 
180
  hyp=hyp,
181
  rect=True,
182
  cache_images=opt.cache_images,
183
+ single_cls=opt.single_cls,
184
+ stride=gs),
185
  batch_size=batch_size,
186
  num_workers=nw,
187
  pin_memory=True,
utils/datasets.py CHANGED
@@ -258,7 +258,7 @@ class LoadStreams: # multiple IP or RTSP cameras
258
 
259
  class LoadImagesAndLabels(Dataset): # for training/testing
260
  def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
261
- cache_images=False, single_cls=False, pad=0.0):
262
  try:
263
  path = str(Path(path)) # os-agnostic
264
  parent = str(Path(path).parent) + os.sep
@@ -325,7 +325,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
325
  elif mini > 1:
326
  shapes[i] = [1, 1 / mini]
327
 
328
- self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32. + pad).astype(np.int) * 32
329
 
330
  # Cache labels
331
  self.imgs = [None] * n
 
258
 
259
  class LoadImagesAndLabels(Dataset): # for training/testing
260
  def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
261
+ cache_images=False, single_cls=False, stride=32, pad=0.0):
262
  try:
263
  path = str(Path(path)) # os-agnostic
264
  parent = str(Path(path).parent) + os.sep
 
325
  elif mini > 1:
326
  shapes[i] = [1, 1 / mini]
327
 
328
+ self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
329
 
330
  # Cache labels
331
  self.imgs = [None] * n