ketan-b glenn-jocher commited on
Commit
9d86b54
1 Parent(s): d3e9d69

Add multi-stream saving feature (#3864)

Browse files

* Added the recording feature for multiple streams

Thanks for the very cool repo!!
I was trying to record multiple feeds at the same time, but the current version of the detector only had one video writer and one vid_path!
So the streams were not being saved and only were initialized with one frame and this process didn't record the whole thing.

Fix:
I made a list of `vid_writer` and `vid_path` and the `i` from the loop over the `pred` took care of the writer which need to work!

I hope this helps, Thanks!

* Cleanup list lengths

* batch size variable

* Update datasets.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (2) hide show
  1. detect.py +10 -8
  2. utils/datasets.py +1 -1
detect.py CHANGED
@@ -76,14 +76,16 @@ def run(weights='yolov5s.pt', # model.pt path(s)
76
  modelc = load_classifier(name='resnet50', n=2) # initialize
77
  modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
78
 
79
- # Set Dataloader
80
- vid_path, vid_writer = None, None
81
  if webcam:
82
  view_img = check_imshow()
83
  cudnn.benchmark = True # set True to speed up constant image size inference
84
  dataset = LoadStreams(source, img_size=imgsz, stride=stride)
 
85
  else:
86
  dataset = LoadImages(source, img_size=imgsz, stride=stride)
 
 
87
 
88
  # Run inference
89
  if device.type != 'cpu':
@@ -158,10 +160,10 @@ def run(weights='yolov5s.pt', # model.pt path(s)
158
  if dataset.mode == 'image':
159
  cv2.imwrite(save_path, im0)
160
  else: # 'video' or 'stream'
161
- if vid_path != save_path: # new video
162
- vid_path = save_path
163
- if isinstance(vid_writer, cv2.VideoWriter):
164
- vid_writer.release() # release previous video writer
165
  if vid_cap: # video
166
  fps = vid_cap.get(cv2.CAP_PROP_FPS)
167
  w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -169,8 +171,8 @@ def run(weights='yolov5s.pt', # model.pt path(s)
169
  else: # stream
170
  fps, w, h = 30, im0.shape[1], im0.shape[0]
171
  save_path += '.mp4'
172
- vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
173
- vid_writer.write(im0)
174
 
175
  if save_txt or save_img:
176
  s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
 
76
  modelc = load_classifier(name='resnet50', n=2) # initialize
77
  modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
78
 
79
+ # Dataloader
 
80
  if webcam:
81
  view_img = check_imshow()
82
  cudnn.benchmark = True # set True to speed up constant image size inference
83
  dataset = LoadStreams(source, img_size=imgsz, stride=stride)
84
+ bs = len(dataset) # batch_size
85
  else:
86
  dataset = LoadImages(source, img_size=imgsz, stride=stride)
87
+ bs = 1 # batch_size
88
+ vid_path, vid_writer = [None] * bs, [None] * bs
89
 
90
  # Run inference
91
  if device.type != 'cpu':
 
160
  if dataset.mode == 'image':
161
  cv2.imwrite(save_path, im0)
162
  else: # 'video' or 'stream'
163
+ if vid_path[i] != save_path: # new video
164
+ vid_path[i] = save_path
165
+ if isinstance(vid_writer[i], cv2.VideoWriter):
166
+ vid_writer[i].release() # release previous video writer
167
  if vid_cap: # video
168
  fps = vid_cap.get(cv2.CAP_PROP_FPS)
169
  w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
171
  else: # stream
172
  fps, w, h = 30, im0.shape[1], im0.shape[0]
173
  save_path += '.mp4'
174
+ vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
175
+ vid_writer[i].write(im0)
176
 
177
  if save_txt or save_img:
178
  s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
utils/datasets.py CHANGED
@@ -352,7 +352,7 @@ class LoadStreams: # multiple IP or RTSP cameras
352
  return self.sources, img, img0, None
353
 
354
  def __len__(self):
355
- return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years
356
 
357
 
358
  def img2label_paths(img_paths):
 
352
  return self.sources, img, img0, None
353
 
354
  def __len__(self):
355
+ return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
356
 
357
 
358
  def img2label_paths(img_paths):