glenn-jocher commited on
Commit
a18efc3
1 Parent(s): aa02b94

Add variable-stride inference support (#2091)

Browse files
Files changed (2) hide show
  1. detect.py +4 -3
  2. utils/datasets.py +13 -10
detect.py CHANGED
@@ -31,7 +31,8 @@ def detect(save_img=False):
31
 
32
  # Load model
33
  model = attempt_load(weights, map_location=device) # load FP32 model
34
- imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
 
35
  if half:
36
  model.half() # to FP16
37
 
@@ -46,10 +47,10 @@ def detect(save_img=False):
46
  if webcam:
47
  view_img = True
48
  cudnn.benchmark = True # set True to speed up constant image size inference
49
- dataset = LoadStreams(source, img_size=imgsz)
50
  else:
51
  save_img = True
52
- dataset = LoadImages(source, img_size=imgsz)
53
 
54
  # Get names and colors
55
  names = model.module.names if hasattr(model, 'module') else model.names
 
31
 
32
  # Load model
33
  model = attempt_load(weights, map_location=device) # load FP32 model
34
+ stride = int(model.stride.max()) # model stride
35
+ imgsz = check_img_size(imgsz, s=stride) # check img_size
36
  if half:
37
  model.half() # to FP16
38
 
 
47
  if webcam:
48
  view_img = True
49
  cudnn.benchmark = True # set True to speed up constant image size inference
50
+ dataset = LoadStreams(source, img_size=imgsz, stride=stride)
51
  else:
52
  save_img = True
53
+ dataset = LoadImages(source, img_size=imgsz, stride=stride)
54
 
55
  # Get names and colors
56
  names = model.module.names if hasattr(model, 'module') else model.names
utils/datasets.py CHANGED
@@ -119,7 +119,7 @@ class _RepeatSampler(object):
119
 
120
 
121
  class LoadImages: # for inference
122
- def __init__(self, path, img_size=640):
123
  p = str(Path(path)) # os-agnostic
124
  p = os.path.abspath(p) # absolute path
125
  if '*' in p:
@@ -136,6 +136,7 @@ class LoadImages: # for inference
136
  ni, nv = len(images), len(videos)
137
 
138
  self.img_size = img_size
 
139
  self.files = images + videos
140
  self.nf = ni + nv # number of files
141
  self.video_flag = [False] * ni + [True] * nv
@@ -181,7 +182,7 @@ class LoadImages: # for inference
181
  print(f'image {self.count}/{self.nf} {path}: ', end='')
182
 
183
  # Padded resize
184
- img = letterbox(img0, new_shape=self.img_size)[0]
185
 
186
  # Convert
187
  img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
@@ -199,8 +200,9 @@ class LoadImages: # for inference
199
 
200
 
201
  class LoadWebcam: # for inference
202
- def __init__(self, pipe='0', img_size=640):
203
  self.img_size = img_size
 
204
 
205
  if pipe.isnumeric():
206
  pipe = eval(pipe) # local camera
@@ -243,7 +245,7 @@ class LoadWebcam: # for inference
243
  print(f'webcam {self.count}: ', end='')
244
 
245
  # Padded resize
246
- img = letterbox(img0, new_shape=self.img_size)[0]
247
 
248
  # Convert
249
  img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
@@ -256,9 +258,10 @@ class LoadWebcam: # for inference
256
 
257
 
258
  class LoadStreams: # multiple IP or RTSP cameras
259
- def __init__(self, sources='streams.txt', img_size=640):
260
  self.mode = 'stream'
261
  self.img_size = img_size
 
262
 
263
  if os.path.isfile(sources):
264
  with open(sources, 'r') as f:
@@ -284,7 +287,7 @@ class LoadStreams: # multiple IP or RTSP cameras
284
  print('') # newline
285
 
286
  # check for common shapes
287
- s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0) # inference shapes
288
  self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
289
  if not self.rect:
290
  print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
@@ -313,7 +316,7 @@ class LoadStreams: # multiple IP or RTSP cameras
313
  raise StopIteration
314
 
315
  # Letterbox
316
- img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]
317
 
318
  # Stack
319
  img = np.stack(img, 0)
@@ -784,8 +787,8 @@ def replicate(img, labels):
784
  return img, labels
785
 
786
 
787
- def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
788
- # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
789
  shape = img.shape[:2] # current shape [height, width]
790
  if isinstance(new_shape, int):
791
  new_shape = (new_shape, new_shape)
@@ -800,7 +803,7 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
800
  new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
801
  dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
802
  if auto: # minimum rectangle
803
- dw, dh = np.mod(dw, 32), np.mod(dh, 32) # wh padding
804
  elif scaleFill: # stretch
805
  dw, dh = 0.0, 0.0
806
  new_unpad = (new_shape[1], new_shape[0])
 
119
 
120
 
121
  class LoadImages: # for inference
122
+ def __init__(self, path, img_size=640, stride=32):
123
  p = str(Path(path)) # os-agnostic
124
  p = os.path.abspath(p) # absolute path
125
  if '*' in p:
 
136
  ni, nv = len(images), len(videos)
137
 
138
  self.img_size = img_size
139
+ self.stride = stride
140
  self.files = images + videos
141
  self.nf = ni + nv # number of files
142
  self.video_flag = [False] * ni + [True] * nv
 
182
  print(f'image {self.count}/{self.nf} {path}: ', end='')
183
 
184
  # Padded resize
185
+ img = letterbox(img0, self.img_size, stride=self.stride)[0]
186
 
187
  # Convert
188
  img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
 
200
 
201
 
202
  class LoadWebcam: # for inference
203
+ def __init__(self, pipe='0', img_size=640, stride=32):
204
  self.img_size = img_size
205
+ self.stride = stride
206
 
207
  if pipe.isnumeric():
208
  pipe = eval(pipe) # local camera
 
245
  print(f'webcam {self.count}: ', end='')
246
 
247
  # Padded resize
248
+ img = letterbox(img0, self.img_size, stride=self.stride)[0]
249
 
250
  # Convert
251
  img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
 
258
 
259
 
260
  class LoadStreams: # multiple IP or RTSP cameras
261
+ def __init__(self, sources='streams.txt', img_size=640, stride=32):
262
  self.mode = 'stream'
263
  self.img_size = img_size
264
+ self.stride = stride
265
 
266
  if os.path.isfile(sources):
267
  with open(sources, 'r') as f:
 
287
  print('') # newline
288
 
289
  # check for common shapes
290
+ s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
291
  self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
292
  if not self.rect:
293
  print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
 
316
  raise StopIteration
317
 
318
  # Letterbox
319
+ img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
320
 
321
  # Stack
322
  img = np.stack(img, 0)
 
787
  return img, labels
788
 
789
 
790
+ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
791
+ # Resize and pad image while meeting stride-multiple constraints
792
  shape = img.shape[:2] # current shape [height, width]
793
  if isinstance(new_shape, int):
794
  new_shape = (new_shape, new_shape)
 
803
  new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
804
  dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
805
  if auto: # minimum rectangle
806
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
807
  elif scaleFill: # stretch
808
  dw, dh = 0.0, 0.0
809
  new_unpad = (new_shape[1], new_shape[0])