oskarastrom commited on
Commit
752c2e9
1 Parent(s): e8a4272

improved inference

Browse files
Files changed (3) hide show
  1. dataloader.py +28 -24
  2. inference.py +0 -1
  3. scripts/infer_frames.py +1 -1
dataloader.py CHANGED
@@ -116,31 +116,12 @@ class YOLOFrameDataset(Dataset):
116
  self.original_shape = (self.ydim, self.xdim)
117
  self.shape = np.ceil(np.array(shape) * img_size / stride + pad).astype(int) * stride
118
 
119
- self.batches = []
120
  for i in range(0,n,batch_size):
121
- self.batches.append((i, min(n, i+batch_size)))
122
-
123
- @classmethod
124
- def load_image(cls, img, img_size=896):
125
- """Loads and resizes 1 image from dataset, returns img, original hw, resized hw.
126
- Modified from ScaledYOLOv4.datasets.load_image()
127
- """
128
-
129
- h0, w0 = img.shape[:2]
130
- h1, w1 = h0, w0
131
- r = img_size / max(h0, w0)
132
- if r != 1: # always resize down, only resize up if training with augmentation
133
- interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
134
- img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
135
- h1, w1 = img.shape[:2]
136
-
137
- return img, (h0, w0), (h1, w1) # img, hw_original, hw_resized
138
 
139
- def __len__(self):
140
- return len(self.batches)
141
-
142
- def __iter__(self):
143
- for batch_idx in self.batches:
144
 
145
  batch = []
146
  labels = None
@@ -164,7 +145,30 @@ class YOLOFrameDataset(Dataset):
164
 
165
  image = torch.stack(batch)
166
 
167
- yield (image, labels, shapes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  class ARISBatchedDataset(Dataset):
170
  def __init__(self, aris_filepath, beam_width_dir, annotations_file, batch_size, num_frames_bg_subtract=1000, disable_output=False,
 
116
  self.original_shape = (self.ydim, self.xdim)
117
  self.shape = np.ceil(np.array(shape) * img_size / stride + pad).astype(int) * stride
118
 
119
+ self.batch_indices = []
120
  for i in range(0,n,batch_size):
121
+ self.batch_indices.append((i, min(n, i+batch_size)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ self.batches = []
124
+ for batch_idx in self.batch_indices:
 
 
 
125
 
126
  batch = []
127
  labels = None
 
145
 
146
  image = torch.stack(batch)
147
 
148
+ self.batches.append((image, labels, shapes))
149
+
150
+ @classmethod
151
+ def load_image(cls, img, img_size=896):
152
+ """Loads and resizes 1 image from dataset, returns img, original hw, resized hw.
153
+ Modified from ScaledYOLOv4.datasets.load_image()
154
+ """
155
+
156
+ h0, w0 = img.shape[:2]
157
+ h1, w1 = h0, w0
158
+ r = img_size / max(h0, w0)
159
+ if r != 1: # always resize down, only resize up if training with augmentation
160
+ interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
161
+ img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
162
+ h1, w1 = img.shape[:2]
163
+
164
+ return img, (h0, w0), (h1, w1) # img, hw_original, hw_resized
165
+
166
+ def __len__(self):
167
+ return len(self.batches)
168
+
169
+ def __iter__(self):
170
+ for batch in self.batches:
171
+ yield batch
172
 
173
  class ARISBatchedDataset(Dataset):
174
  def __init__(self, aris_filepath, beam_width_dir, annotations_file, batch_size, num_frames_bg_subtract=1000, disable_output=False,
inference.py CHANGED
@@ -125,7 +125,6 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verb
125
  size = tuple(img.shape)
126
  nb, _, height, width = size # batch size, channels, height, width
127
 
128
- print(nb, _, height, width)
129
  # Run model & NMS
130
  with torch.no_grad():
131
  inf_out, _ = model(img, augment=False)
 
125
  size = tuple(img.shape)
126
  nb, _, height, width = size # batch size, channels, height, width
127
 
 
128
  # Run model & NMS
129
  with torch.no_grad():
130
  inf_out, _ = model(img, augment=False)
scripts/infer_frames.py CHANGED
@@ -39,7 +39,7 @@ def main(args, config={}, verbose=True):
39
 
40
  dirname = args.frames
41
 
42
- locations = ["test"]
43
  for loc in locations:
44
 
45
  in_loc_dir = os.path.join(dirname, loc)
 
39
 
40
  dirname = args.frames
41
 
42
+ locations = ["kenai-val"]
43
  for loc in locations:
44
 
45
  in_loc_dir = os.path.join(dirname, loc)