import torch
import numpy as np
from collections import defaultdict
from mmdet.models.task_modules.assigners import BboxOverlaps2D
from mmengine.structures import InstanceData


def average_score_filter(instances_list):
    """Replace per-frame scores with the average score of each contiguous
    track segment, so a track's confidence is stable within a segment."""
    # Extract instance IDs and their scores.
    instance_id_to_frames = defaultdict(list)
    instance_id_to_scores = defaultdict(list)

    for frame_idx, instances in enumerate(instances_list):
        for i, instance_id in enumerate(instances[0].pred_track_instances.instances_id):
            instance_id_to_frames[instance_id.item()].append(frame_idx)
            instance_id_to_scores[instance_id.item()].append(
                instances[0].pred_track_instances.scores[i].cpu().numpy())

    # Compute the average score for each segment of each instance ID.
    for instance_id, frames in instance_id_to_frames.items():
        scores = np.array(instance_id_to_scores[instance_id])

        # Split the track into segments of consecutive frames.
        segments = []
        segment = [frames[0]]
        for idx in range(1, len(frames)):
            if frames[idx] == frames[idx - 1] + 1:
                segment.append(frames[idx])
            else:
                segments.append(segment)
                segment = [frames[idx]]
        segments.append(segment)

        # Average the scores within each segment.
        avg_scores = np.copy(scores)
        for segment in segments:
            start = frames.index(segment[0])
            end = frames.index(segment[-1]) + 1
            avg_scores[start:end] = np.mean(scores[start:end])

        # Write the averaged scores back into instances_list.
        for frame_idx, avg_score in zip(frames, avg_scores):
            track_instances = instances_list[frame_idx][0].pred_track_instances
            track_instances.scores[track_instances.instances_id == instance_id] = \
                torch.tensor(avg_score, dtype=track_instances.scores.dtype)

    return instances_list


def moving_average_filter(instances_list, window_size=5):
    """Smooth the bounding boxes of each contiguous track segment with an
    edge-padded moving average of length ``window_size``."""

    # Helper to compute the moving average per box coordinate.
    # NOTE: with mode='valid', the output length matches the input only for
    # odd window sizes; ``window_size`` is assumed to be odd.
    def smooth_bbox(bboxes, window_size):
        smoothed_bboxes = np.copy(bboxes)
        half_window = window_size // 2
        for i in range(4):
            padded_bboxes = np.pad(bboxes[:, i], (half_window, half_window), mode='edge')
            smoothed_bboxes[:, i] = np.convolve(
                padded_bboxes, np.ones(window_size) / window_size, mode='valid')
        return smoothed_bboxes

    # Extract bounding boxes and instance IDs.
    instance_id_to_frames = defaultdict(list)
    instance_id_to_bboxes = defaultdict(list)

    for frame_idx, instances in enumerate(instances_list):
        for i, instance_id in enumerate(instances[0].pred_track_instances.instances_id):
            instance_id_to_frames[instance_id.item()].append(frame_idx)
            instance_id_to_bboxes[instance_id.item()].append(
                instances[0].pred_track_instances.bboxes[i].cpu().numpy())

    # Apply the moving average filter to each segment.
    for instance_id, frames in instance_id_to_frames.items():
        bboxes = np.array(instance_id_to_bboxes[instance_id])

        # Split the track into segments of consecutive frames.
        segments = []
        segment = [frames[0]]
        for idx in range(1, len(frames)):
            if frames[idx] == frames[idx - 1] + 1:
                segment.append(frames[idx])
            else:
                segments.append(segment)
                segment = [frames[idx]]
        segments.append(segment)

        # Smooth the bounding boxes of each segment that is long enough.
        smoothed_bboxes = np.copy(bboxes)
        for segment in segments:
            if len(segment) >= window_size:
                start = frames.index(segment[0])
                end = frames.index(segment[-1]) + 1
                smoothed_bboxes[start:end] = smooth_bbox(bboxes[start:end], window_size)

        # Write the smoothed boxes back into instances_list.
        for frame_idx, smoothed_bbox in zip(frames, smoothed_bboxes):
            track_instances = instances_list[frame_idx][0].pred_track_instances
            track_instances.bboxes[track_instances.instances_id == instance_id] = \
                torch.tensor(smoothed_bbox, dtype=track_instances.bboxes.dtype).to(
                    track_instances.bboxes.device)

    return instances_list
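

# Illustrative sketch (not part of the original pipeline): how the edge-padded
# moving average inside ``smooth_bbox`` behaves on a single coordinate track.
# The numbers below are synthetic and chosen purely for illustration.
def _demo_edge_padded_moving_average(window_size=5):
    xs = np.array([0., 1., 50., 3., 4.])  # frame 2 is an outlier jump
    half_window = window_size // 2
    # Edge padding keeps the 'valid' convolution the same length as the input
    # (for odd window sizes) while damping the outlier toward its neighbours.
    padded = np.pad(xs, (half_window, half_window), mode='edge')
    smoothed = np.convolve(padded, np.ones(window_size) / window_size, mode='valid')
    assert smoothed.shape == xs.shape
    return smoothed  # array([10.2, 10.8, 11.6, 12.4, 13. ])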


def identify_and_remove_giant_bounding_boxes(instances_list, image_size, size_threshold,
                                             confidence_threshold, coverage_threshold,
                                             object_num_thr=4, max_objects_in_box=6):
    """Drop tracks whose boxes are implausibly large: boxes that swallow many
    confident detections, or that cover most of the image."""
    # Overlap calculator; 'iof' (intersection over foreground) is selected
    # per call below, not at construction time.
    bbox_overlaps_calculator = BboxOverlaps2D()

    # Initialize data structures.
    invalid_instance_ids = set()
    image_width, image_height = image_size
    two_thirds_image_area = (2 / 3) * (image_width * image_height)

    # Step 1: Identify giant bounding boxes and record their instance_ids.
    for frame_idx, instances in enumerate(instances_list):
        bounding_boxes = instances[0].pred_track_instances.bboxes
        confidence_scores = instances[0].pred_track_instances.scores
        instance_ids = instances[0].pred_track_instances.instances_id
        N = bounding_boxes.size(0)

        for i in range(N):
            current_box = bounding_boxes[i]
            box_size = (current_box[2] - current_box[0]) * (current_box[3] - current_box[1])
            if box_size < size_threshold:
                continue

            other_boxes = torch.cat([bounding_boxes[:i], bounding_boxes[i + 1:]])
            other_confidences = torch.cat([confidence_scores[:i], confidence_scores[i + 1:]])

            # IoF of every other box w.r.t. the current (large) box.
            iofs = bbox_overlaps_calculator(other_boxes, current_box.unsqueeze(0),
                                            mode='iof', is_aligned=False)
            if iofs.numel() == 0:
                continue

            high_conf_mask = other_confidences > confidence_threshold
            if not high_conf_mask.any():
                continue
            high_conf_masked_iofs = iofs[high_conf_mask]
            covered_high_conf_boxes_count = torch.sum(high_conf_masked_iofs > coverage_threshold)

            # A large box that covers several more-confident boxes is invalid.
            if covered_high_conf_boxes_count >= object_num_thr and torch.all(
                    confidence_scores[i] < other_confidences[high_conf_mask]):
                invalid_instance_ids.add(instance_ids[i].item())
                continue

            # A box covering more than two thirds of the image is invalid.
            if box_size > two_thirds_image_area:
                invalid_instance_ids.add(instance_ids[i].item())
                continue

            # A box containing more than ``max_objects_in_box`` confident
            # objects is invalid.
            if covered_high_conf_boxes_count > max_objects_in_box:
                invalid_instance_ids.add(instance_ids[i].item())
                continue

    # Step 2: Remove the invalid tracks from every frame.
    for frame_idx, instances in enumerate(instances_list):
        valid_mask = torch.tensor([instance_id.item() not in invalid_instance_ids
                                   for instance_id in instances[0].pred_track_instances.instances_id])
        if len(valid_mask) == 0:
            continue
        new_instance_data = InstanceData()
        new_instance_data.bboxes = instances[0].pred_track_instances.bboxes[valid_mask]
        new_instance_data.scores = instances[0].pred_track_instances.scores[valid_mask]
        new_instance_data.instances_id = instances[0].pred_track_instances.instances_id[valid_mask]
        new_instance_data.labels = instances[0].pred_track_instances.labels[valid_mask]
        if 'masks' in instances[0].pred_track_instances:
            new_instance_data.masks = instances[0].pred_track_instances.masks[valid_mask]
        instances[0].pred_track_instances = new_instance_data

    return instances_list
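

# Minimal data-layout sketch (an assumption, for illustration only): each
# element of ``instances_list`` is expected to be a list whose first entry
# exposes ``pred_track_instances``. ``SimpleNamespace`` stands in for the real
# track data sample type purely so the functions above can be smoke-tested;
# all field values are made up.
def _build_fake_frame(bboxes, scores, ids, labels):
    from types import SimpleNamespace
    data = InstanceData()
    data.bboxes = torch.tensor(bboxes, dtype=torch.float32)
    data.scores = torch.tensor(scores, dtype=torch.float32)
    data.instances_id = torch.tensor(ids, dtype=torch.long)
    data.labels = torch.tensor(labels, dtype=torch.long)
    return [SimpleNamespace(pred_track_instances=data)]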


def filter_and_update_tracks(instances_list, image_size, size_threshold=10000,
                             coverage_threshold=0.75, confidence_threshold=0.2,
                             smoothing_window_size=5):
    """Full post-processing pipeline: drop giant boxes, smooth trajectories,
    then stabilize per-segment scores."""
    # Step 1: Identify and remove giant bounding boxes.
    instances_list = identify_and_remove_giant_bounding_boxes(
        instances_list, image_size, size_threshold, confidence_threshold, coverage_threshold)

    # Step 2: Smooth the interpolated bounding boxes.
    instances_list = moving_average_filter(instances_list, window_size=smoothing_window_size)

    # Step 3: Compute the track average score per segment.
    instances_list = average_score_filter(instances_list)

    return instances_list
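

# Hedged usage example: a two-frame smoke test with the fake frames above and
# the default thresholds. The boxes are far below ``size_threshold`` so the
# giant-box filter is a no-op, the tracks are shorter than the smoothing
# window so the boxes pass through unchanged, and each track's scores are
# averaged over its two-frame segment (e.g. 0.90 and 0.85 both become 0.875).
if __name__ == '__main__':
    frames = [
        _build_fake_frame([[0, 0, 10, 10], [20, 20, 30, 30]], [0.90, 0.80], [1, 2], [0, 0]),
        _build_fake_frame([[1, 1, 11, 11], [21, 21, 31, 31]], [0.85, 0.75], [1, 2], [0, 0]),
    ]
    frames = filter_and_update_tracks(frames, image_size=(640, 480))
    print(frames[0][0].pred_track_instances.scores)  # tensor([0.8750, 0.7750])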