oskarastrom commited on
Commit
2ceaa4e
1 Parent(s): 7ce97ad

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +14 -17
inference.py CHANGED
@@ -268,23 +268,20 @@ def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, c
268
 
269
  frames = [None, None]
270
 
271
- next_frame = None
272
- if i+1 < len(infer):
273
- next_frame = infer[i+1]
274
- elif has_next_batch:
275
- next_frame = inference[batch_i + 1][0]
276
-
277
- if next_frame != None:
278
- boost_frame(safe_frame, next_frame, 1, decay=conf_decay)
279
-
280
- prev_frame = None
281
- if i-1 >= 0:
282
- prev_frame = infer[i-1]
283
- elif has_prev_batch:
284
- prev_frame = inference[batch_i - 1][len(inference[batch_i - 1]) - 1]
285
-
286
- if prev_frame != None:
287
- boost_frame(safe_frame, prev_frame, -1, power=conf_power, decay=conf_decay)
288
 
289
  pbar.update(1*batch_size)
290
 
 
268
 
269
  frames = [None, None]
270
 
271
+ for dt in [-1, 1]:
272
+ idx = i+dt
273
+ temp_frame = None
274
+ if idx >= 0 and idx < len(infer):
275
+ temp_frame = infer[idx]
276
+ elif idx < 0 and has_prev_batch:
277
+ prev_batch = inference[batch_i - 1]
278
+ temp_frame = prev_batch[idx]
279
+ elif idx >= len(infer) and has_next_batch:
280
+ next_batch = inference[batch_i + 1]
281
+ temp_frame = next_batch[idx - len(infer)]
282
+
283
+ if temp_frame is not None:
284
+ boost_frame(safe_frame, temp_frame, dt, power=conf_power, decay=conf_decay)
 
 
 
285
 
286
  pbar.update(1*batch_size)
287