Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Interactive Pseudolabeling Tool for COCO Dataset | |
| Features: | |
| - Visual inspection with OpenCV | |
| - Click to remove false positive bboxes | |
| - Compare predictions with ground truth | |
| - Save pseudolabels to COCO format | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import math | |
| import os | |
| import shutil | |
| import time | |
| from collections import defaultdict | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import List, Optional | |
| import cv2 | |
| import numpy as np | |
| import torch | |
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python values."""

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj``.

        NumPy booleans, integers and floats become ``bool`` / ``int`` /
        ``float`` and arrays become (nested) lists. Anything else is handed
        to the base class, which raises ``TypeError``.
        """
        # np.bool_ is not a subclass of np.integer, so it needs its own branch;
        # without it a NumPy boolean would fall through to TypeError.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
# Configure the root logger at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class BBox:
    """Bounding box in corner form ``(x1, y1, x2, y2)`` with metadata.

    The class is a dataclass so it can be built positionally
    (``BBox(x1, y1, x2, y2, score, ...)``), which is how ``from_coco`` and
    the rest of the file construct it.
    """

    x1: float
    y1: float
    x2: float
    y2: float
    score: float
    category_id: int = 0
    source: str = "predicted"  # "predicted", "original", "manual"
    id: Optional[int] = None
    area: Optional[float] = None

    def to_coco(self):
        """Return the box as a COCO-style ``[x, y, width, height]`` list."""
        return [self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1]

    @classmethod
    def from_coco(cls, bbox_list, score=1.0, source="original", **kwargs):
        """Build a box from a COCO-style ``[x, y, width, height]`` sequence.

        Extra keyword arguments (e.g. ``id``, ``category_id``, ``area``) are
        forwarded to the constructor unchanged.
        """
        x, y, w, h = bbox_list
        return cls(x, y, x + w, y + h, score, source=source, **kwargs)

    def iou(self, other):
        """Return intersection-over-union with ``other`` (0.0 when disjoint)."""
        x1 = max(self.x1, other.x1)
        y1 = max(self.y1, other.y1)
        x2 = min(self.x2, other.x2)
        y2 = min(self.y2, other.y2)
        if x2 < x1 or y2 < y1:
            return 0.0
        intersection = (x2 - x1) * (y2 - y1)
        area1 = (self.x2 - self.x1) * (self.y2 - self.y1)
        area2 = (other.x2 - other.x1) * (other.y2 - other.y1)
        union = area1 + area2 - intersection
        # Guard against two degenerate (zero-area) boxes.
        return intersection / union if union > 0 else 0.0

    def overflow_area(self, other):
        """Return how far this box sticks out of ``other``, relative to ``other``'s area.

        For each side, the amount by which this box extends beyond ``other``
        is multiplied by this box's full height (left/right) or full width
        (top/bottom); the four strips are summed, so corner regions are
        counted twice. The sum is divided by the area of ``other`` (plus a
        small epsilon to avoid division by zero).
        """
        pred_x1, pred_y1, pred_x2, pred_y2 = self.x1, self.y1, self.x2, self.y2
        gt_x1, gt_y1, gt_x2, gt_y2 = other.x1, other.y1, other.x2, other.y2
        overflow_left = max(0, gt_x1 - pred_x1)
        overflow_top = max(0, gt_y1 - pred_y1)
        overflow_right = max(0, pred_x2 - gt_x2)
        overflow_bottom = max(0, pred_y2 - gt_y2)
        pred_width = pred_x2 - pred_x1
        pred_height = pred_y2 - pred_y1
        overflow = (overflow_left * pred_height +
                    overflow_right * pred_height +
                    overflow_top * pred_width +
                    overflow_bottom * pred_width)
        gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
        return overflow / (gt_area + 1e-6)

    def calc_area(self, hw):
        """Return the box area.

        When ``hw`` is an ``(height, width)`` pair the coordinates are first
        divided by the image size, giving the area as a fraction of the
        image; when ``hw`` is ``None`` the area is in squared coordinate
        units.
        """
        if hw is not None:
            h, w = hw
            return (self.x2 / w - self.x1 / w) * (self.y2 / h - self.y1 / h)
        return (self.x2 - self.x1) * (self.y2 - self.y1)
class PedestrianDetector:
    """Runs one or more TorchScript detection models and merges their boxes.

    Models whose scripted ``original_name`` contains ``"deim"`` are handled
    by a separate pre/post-processing path (fixed input size, ``/255``
    scaling, centre-format normalised boxes with logit scores). Every other
    model is assumed to return rows that already look like
    ``[x1, y1, x2, y2, score]`` -- TODO confirm against the exported models.
    """

    # Input (height, width) used for DEIM models, both when padding the image
    # and when de-normalising the boxes the model returns.
    DEIM_INPUT_SIZE = (800, 1024)

    def __init__(self,
                 model_paths,
                 target_size=(800, 1333),
                 tta=False,
                 tile_grid=(1, 1),
                 nms_thr=0.5):
        """Load every model in ``model_paths``.

        Args:
            model_paths: paths of TorchScript model files.
            target_size: ``(height, width)`` the image is letterboxed to for
                non-DEIM models.
            tta: stored on the instance; not used by the code in this class.
            tile_grid: stored on the instance; not used by the code in this class.
            nms_thr: IoU threshold for the final NMS; ``<= 0`` disables it.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.target_size = target_size
        self.tta = tta
        self.tile_grid = tuple(tile_grid)
        self.nms_thr = nms_thr
        self.models = [
            self._load_model(model_path)
            for model_path in model_paths
        ]
        # NOTE(review): these look like the usual ImageNet statistics in RGB
        # order, while cv2 images are BGR and no channel swap happens in
        # _preprocess_image -- confirm this matches how the models were exported.
        self.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
        self.std = np.array([58.395, 57.12, 57.375], dtype=np.float32)

    @staticmethod
    def _is_deim(model_name):
        """Return True when ``model_name`` identifies a DEIM model."""
        return "deim" in model_name

    def _load_model(self, model_path):
        """Load a TorchScript model; paths containing ``"cpu"`` stay on the CPU."""
        model_path = str(model_path)
        assert model_path.endswith('.pt') or '_traced' in model_path, \
            f"Expected a traced .pt model, got {model_path}"
        # The substring "cpu" anywhere in the path pins the model to the CPU.
        keep_on_cpu = "cpu" in model_path
        m = torch.jit.load(model_path, map_location="cpu" if keep_on_cpu else self.device)
        m.eval()
        return m if keep_on_cpu else m.to(self.device)

    def _preprocess_image(self, image, model_name: str):
        """Letterbox ``image`` to the model input size and return ``(tensor, scale)``.

        The image is resized with its aspect ratio preserved, padded with
        zeros on the bottom/right, normalised, and converted to a
        ``1 x C x H x W`` float tensor. ``scale`` is the resize factor, needed
        later to map boxes back to the original image.
        """
        is_deim = self._is_deim(model_name)
        target_size = self.DEIM_INPUT_SIZE if is_deim else self.target_size
        # DEIM inputs are always placed on the CPU.
        device = "cpu" if is_deim else self.device
        h, w = image.shape[:2]
        scale = min(target_size[0] / h, target_size[1] / w)
        new_h, new_w = int(h * scale), int(w * scale)
        resized = cv2.resize(image, (new_w, new_h))
        pad_h = target_size[0] - new_h
        pad_w = target_size[1] - new_w
        padded = cv2.copyMakeBorder(
            resized, 0, pad_h, 0, pad_w,
            cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
        if is_deim:
            norm = padded.astype(np.float32) / 255
        else:
            norm = (padded.astype(np.float32) - self.mean) / self.std
        tensor = torch.from_numpy(norm.transpose(2, 0, 1))[None].float().to(device)
        return tensor, scale

    def _postprocess_detections(self, output, thresh):
        """Keep rows of the first batch element whose column 4 is ``>= thresh``."""
        bboxes, _ = output
        b_np = bboxes[0].cpu().numpy()
        mask = b_np[:, 4] >= thresh
        if not mask.any():
            return np.zeros((0, 5), dtype=np.float32)
        return b_np[mask]

    def _postprocess_detections_deim(self, output, hw, thresh):
        """Convert DEIM output to ``[x1, y1, x2, y2, score]`` rows.

        ``output`` is ``(scores, bboxes)``: scores are logits (a sigmoid is
        applied here) and boxes are ``(cx, cy, w, h)`` which are multiplied
        by the ``(height, width)`` in ``hw``.
        """
        h, w = hw
        scores, bboxes = output
        b_np = bboxes[0].cpu().numpy()
        # Flatten instead of squeeze(): squeeze() collapses a single query to
        # a 0-d mask, which then indexes the box array incorrectly.
        s_np = scores.sigmoid()[0].cpu().numpy().reshape(-1)
        mask = s_np >= thresh
        if not mask.any():
            return np.zeros((0, 5), dtype=np.float32)
        valid = b_np[mask]
        cx, cy, box_w, box_h = valid[:, 0], valid[:, 1], valid[:, 2], valid[:, 3]
        x1 = cx - box_w / 2
        y1 = cy - box_h / 2
        x2 = cx + box_w / 2
        y2 = cy + box_h / 2
        valid_xyxy = np.stack([x1, y1, x2, y2], axis=1) * [w, h, w, h]
        return np.concatenate([valid_xyxy, s_np[mask][:, None]], axis=1)

    def _rescale_bboxes(self, dets, scale):
        """Divide the four box columns by ``scale`` in place and return ``dets``."""
        if dets.shape[0] == 0:
            return dets
        dets[:, :4] = dets[:, :4] / scale
        return dets

    @staticmethod
    def _nms(dets, iou_thr):
        """Greedy non-maximum suppression over ``[x1, y1, x2, y2, score]`` rows.

        Boxes are visited in descending score order; every remaining box
        whose IoU with the current one exceeds ``iou_thr`` is dropped. Widths
        and heights use the ``+ 1`` pixel convention. An empty input or a
        non-positive threshold returns ``dets`` unchanged.
        """
        if dets.shape[0] == 0 or iou_thr <= 0:
            return dets
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            rest = order[1:]
            xx1 = np.maximum(x1[i], x1[rest])
            yy1 = np.maximum(y1[i], y1[rest])
            xx2 = np.minimum(x2[i], x2[rest])
            yy2 = np.minimum(y2[i], y2[rest])
            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            iou = inter / (areas[i] + areas[rest] - inter)
            inds = np.where(iou <= iou_thr)[0]
            # +1 because ``inds`` indexes ``rest``, which starts at order[1].
            order = order[inds + 1]
        return dets[keep]

    def predict(self, image, thresh: float, model_num: Optional[int] = None):
        """Run the models on ``image`` and return the merged detections.

        Args:
            image: an image array, or a path string that is read with OpenCV.
            thresh: minimum score for a detection to be kept.
            model_num: index of a single model to run; ``None`` runs all.

        Returns:
            A list of ``BBox`` objects with ``source="predicted"``; empty
            when no model produced a detection above ``thresh``.

        Raises:
            ValueError: if ``image`` is a path that cannot be read.
        """
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Could not load image: {image}")
        else:
            img = image
        models = self.models if model_num is None else [self.models[model_num]]
        all_preds = []
        for model in models:
            # Best-effort probe of the scripted module's class name; any
            # failure simply selects the generic (non-DEIM) path.
            try:
                name = model.model.original_name.lower()
            except Exception:
                name = "other"
            tensor, scale = self._preprocess_image(img, name)
            with torch.no_grad():
                out = model(tensor)
            if self._is_deim(name):
                dets = self._postprocess_detections_deim(out, self.DEIM_INPUT_SIZE, thresh)
            else:
                dets = self._postprocess_detections(out, thresh)
            if dets.shape[0] > 0:
                all_preds.append(self._rescale_bboxes(dets, scale))
        if not all_preds:
            return []
        merged = np.vstack(all_preds)
        if self.nms_thr > 0:
            merged = self._nms(merged, self.nms_thr)
        # Convert NumPy scalars to plain floats so boxes serialise cleanly.
        return [
            BBox(float(x1), float(y1), float(x2), float(y2), float(score), source="predicted")
            for x1, y1, x2, y2, score in merged
        ]
| class InteractivePseudolabeler: | |
def __init__(self,
             images_path: str,
             annotations_json: str,
             model_paths: List[str],
             iou_threshold: float = 0.5,
             dataset_filter: str = None,
             force_repredict: bool = False,
             refine_mode: bool = False):
    """Load the COCO annotation file, build lookups and create the detector.

    A one-time ``<stem>_backup.json`` copy of the annotation file is written
    next to it if none exists yet. When ``dataset_filter`` is given, images
    are selected by an exact (case-insensitive) match on their ``dataset``
    field, falling back to a substring match if nothing matches exactly.
    Saved progress, if any, is loaded once the tracking state has been reset.
    """
    # Paths and mode flags.
    self.images_dir = Path(images_path)
    self.annotations_json = Path(annotations_json)
    self.dataset_filter = dataset_filter
    self.force_repredict = force_repredict
    self.refine_mode = refine_mode

    annotations_dir = self.annotations_json.parent
    self.progress_file = annotations_dir / "pseudolabel_progress.json"

    # Back up the annotation file once, before anything can modify it.
    backup_file = annotations_dir / f"{self.annotations_json.stem}_backup.json"
    if not backup_file.exists():
        shutil.copy2(self.annotations_json, backup_file)
        logger.info(f"Created backup: {backup_file}")

    logger.info(f"Loading annotations from {self.annotations_json}")
    with open(self.annotations_json, 'r') as f:
        self.coco_data = json.load(f)

    # Baseline list of annotations as found on disk; used when writing back
    # so that images not visited in this session keep their annotations.
    self.original_coco_annotations = self.coco_data['annotations'].copy()

    every_image = self.coco_data['images']
    if self.dataset_filter:
        logger.info(f"Filtering images by dataset: {self.dataset_filter}")
        wanted = self.dataset_filter.lower()
        selected = [
            img for img in every_image
            if img.get('dataset', '').lower() == wanted
        ]
        if not selected:
            # No exact match: accept any dataset name containing the filter.
            selected = [
                img for img in every_image
                if wanted in img.get('dataset', '').lower()
            ]
        self.all_images = selected
        logger.info(f"Found {len(self.all_images)} images from {self.dataset_filter}")
    else:
        self.all_images = every_image

    # Separate list object so edits do not alias coco_data['annotations'].
    self.all_annotations = self.coco_data['annotations'].copy()
    self.image_id_to_anns = defaultdict(list)
    for ann in self.all_annotations:
        self.image_id_to_anns[ann['image_id']].append(ann)

    # Statistics about annotations already flagged as pseudolabels.
    pseudo_image_ids = [
        ann['image_id'] for ann in self.all_annotations
        if ann.get('is_pseudolabel', False)
    ]
    logger.info(f"Loaded {len(self.all_annotations)} annotations ({len(pseudo_image_ids)} pseudolabels)")
    images_with_pseudo = set(pseudo_image_ids)
    if images_with_pseudo:
        logger.info(f" Pseudolabels found in {len(images_with_pseudo)} images")

    self.image_id_to_info = {img['id']: img for img in self.all_images}

    logger.info("Initializing detector models...")
    self.detector = PedestrianDetector(
        model_paths=model_paths,
        nms_thr=0.8
    )
    self.model_paths = model_paths
    self.refine_score_threshold = 0.01
    self.predict_score_threshold = 0.3

    # Per-image working state.
    self.current_idx = 0
    self.current_image = None
    self.current_bboxes = []
    self.original_bboxes = []
    self.predicted_bboxes = []
    self.similarity_scores = {}
    self.using_existing_pseudolabels = False

    # T press counter and the score thresholds it steps through.
    self.t_press_count = 0
    self.threshold_steps = [0.3, 0.25, 0.2, 0.15, 0.1, 0.05, 0.01, 0.005, 0.001]

    self.processed_images = set()  # images processed in any session
    self.session_visited_images = set()  # images visited in this session only

    # Output structure; annotations start empty and are filled per visited image.
    self.working_data = {
        "info": self.coco_data.get("info", {}),
        "licenses": self.coco_data.get("licenses", []),
        "categories": self.coco_data.get("categories", []),
        "images": self.coco_data['images'].copy(),
        "annotations": []
    }
    highest_id = max((ann['id'] for ann in self.all_annotations), default=0)
    self.annotation_id_counter = highest_id + 1

    # May overwrite processed_images and current_idx set above.
    self.load_progress()

    # UI state.
    self.window_name = "Interactive Pseudolabeler"
    self.mouse_x = 0
    self.mouse_y = 0
    self.show_original = True
    self.show_predicted = True
    self.iou_threshold = iou_threshold
    self.hovered_bbox = None  # bbox currently under the mouse
    self.auto_mode = False  # automatic processing mode
    self.auto_predict_mode = False  # automatic predict-and-process mode
    self.auto_mode_delay = 500  # delay in ms between auto-processed images
def load_progress(self):
    """Restore ``processed_images`` and ``current_idx`` from the progress file.

    Does nothing when the progress file does not exist. Missing keys fall
    back to an empty set and index 0.
    """
    if not self.progress_file.exists():
        return
    with open(self.progress_file, 'r') as f:
        progress = json.load(f)
    self.processed_images = set(progress.get('processed_images', []))
    self.current_idx = progress.get('current_idx', 0)
    logger.info(f"Loaded progress: {len(self.processed_images)} images processed")
def save_progress(self):
    """Write the processed-image ids, current index and a timestamp to the progress file."""
    snapshot = dict(
        processed_images=list(self.processed_images),
        current_idx=self.current_idx,
        timestamp=time.time(),
    )
    with open(self.progress_file, 'w') as f:
        # NumpyEncoder so NumPy scalars among the values serialise as plain numbers.
        json.dump(snapshot, f, indent=2, cls=NumpyEncoder)
def reload_annotations_from_disk(self):
    """Re-read the annotation file and rebuild every in-memory annotation view.

    ``coco_data``, ``all_annotations``, ``original_coco_annotations`` and
    ``image_id_to_anns`` are replaced by what is on disk. ``working_data``
    then holds only the on-disk annotations of images visited in this session.
    """
    logger.info("Reloading annotations from disk...")
    with open(self.annotations_json, 'r') as f:
        self.coco_data = json.load(f)

    on_disk = self.coco_data['annotations']
    self.all_annotations = on_disk.copy()
    # The baseline used for later writes follows the file as well.
    self.original_coco_annotations = on_disk.copy()

    # Regroup annotations by image id, reusing the existing mapping object.
    self.image_id_to_anns.clear()
    for ann in self.all_annotations:
        self.image_id_to_anns[ann['image_id']].append(ann)

    # Only images visited in this session belong in working_data.
    session_ids = self.session_visited_images
    self.working_data['annotations'] = [
        ann for ann in self.all_annotations
        if ann['image_id'] in session_ids
    ]

    logger.info(f"Reloaded {len(self.all_annotations)} annotations from disk")
    logger.info(f" Working data contains {len(self.working_data['annotations'])} annotations for {len(self.session_visited_images)} session-visited images")
def validate_annotations(self):
    """Check that every working annotation is also present in ``image_id_to_anns``.

    Annotations are compared by their ``(id, image_id)`` pair. Pseudolabel
    counts are logged at debug level only. Returns ``True`` when no
    inconsistency was found.
    """
    def key_of(ann):
        return (ann['id'], ann['image_id'])

    working = self.working_data['annotations']
    working_ids = {key_of(ann) for ann in working}
    mapped_ids = {
        key_of(ann)
        for anns in self.image_id_to_anns.values()
        for ann in anns
    }

    issues = []
    missing_in_map = working_ids - mapped_ids
    if missing_in_map:
        issues.append(f"Annotations in working_data but not in image_id_to_anns: {missing_in_map}")

    # Pseudolabel counts are informational only.
    pseudo_working = len([ann for ann in working if ann.get('is_pseudolabel', False)])
    pseudo_all = len([ann for ann in self.all_annotations if ann.get('is_pseudolabel', False)])
    logger.debug(f"Validation: {pseudo_working} pseudolabels in working_data, {pseudo_all} in all_annotations")

    if issues:
        logger.warning(f"Validation issues found: {issues}")
    return not issues
def calculate_similarity_score(self, predicted: List[BBox], original: List[BBox]) -> float:
    """Return the mean of ``best IoU x score`` over all predicted boxes.

    For each prediction the highest IoU against any original box is
    multiplied by the prediction's score. The result is 0.0 when either
    list is empty.
    """
    if not (predicted and original):
        return 0.0
    weighted = []
    for pred in predicted:
        overlaps = [pred.iou(orig) for orig in original]
        weighted.append(max(overlaps) * pred.score)
    return np.mean(weighted)
def refine_predictions(self, predicted: List[BBox], original: List[BBox], hw) -> List[BBox]:
    """Replace ground-truth boxes by their best-matching prediction where one exists.

    Each ground-truth box, in order, claims the not-yet-used prediction that
    overlaps it (IoU >= 0.01) and maximises ``IoU + score - overflow``. The
    claim is accepted only if that prediction's score reaches
    ``self.refine_score_threshold``. Returns the accepted predictions
    followed by the ground-truth boxes that got no match. ``hw`` is accepted
    but not used.
    """
    if not predicted or not original:
        return original.copy() if original else []

    def pick_candidate(gt_bbox, taken):
        """Return ``(index, prediction)`` of the best free candidate, or ``None``."""
        winner = None
        winner_score = -10000
        for i, pred_bbox in enumerate(predicted):
            if i in taken:
                continue
            iou = gt_bbox.iou(pred_bbox)
            if iou < 0.01:
                continue
            combined_score = iou + pred_bbox.score - pred_bbox.overflow_area(gt_bbox)
            if combined_score > winner_score:
                winner_score = combined_score
                winner = (i, pred_bbox)
        return winner

    refined = []
    used_predictions = set()
    matched_gt_indices = set()
    for gt_idx, gt_bbox in enumerate(original):
        best_match = pick_candidate(gt_bbox, used_predictions)
        if not best_match:
            continue
        if best_match[1].score < self.refine_score_threshold:
            continue
        used_predictions.add(best_match[0])
        matched_gt_indices.add(gt_idx)
        refined.append(best_match[1])
        logger.info(
            f"Refined: GT bbox matched with prediction (conf={best_match[1].score:.3f}, iou={gt_bbox.iou(best_match[1]):.3f})")

    # Ground-truth boxes without an accepted prediction are kept as they are.
    unmatched_gt = [
        gt for idx, gt in enumerate(original)
        if idx not in matched_gt_indices
    ]
    logger.info(f"Refinement: {len(original)} GT boxes -> {len(refined)} refined predictions + {len(unmatched_gt)} unmatched GT boxes")
    return refined + unmatched_gt
def load_image(self, idx: int, auto_predict: bool = True) -> bool:
    """Make image ``idx`` the current image and populate its box lists.

    The annotations of the previously shown image are saved first. The new
    image's stored annotations are split into original boxes and existing
    pseudolabels; depending on the flags, the pseudolabels are reused or the
    detector is run in refine mode.

    Returns:
        ``True`` when the image was loaded; ``False`` when ``idx`` is out of
        range or the file is missing/unreadable. In the latter two cases
        ``current_idx`` has already been moved to ``idx``.

    NOTE(review): the nesting below was reconstructed from a flattened
    listing -- confirm the branch structure against the real file.
    """
    if idx < 0 or idx >= len(self.all_images):
        return False
    # Save current annotations before switching images (if we have a current image)
    if self.current_idx >= 0 and self.current_idx < len(self.all_images) and self.current_image is not None:
        self.save_current_annotations()
    self.current_idx = idx
    img_info = self.all_images[idx]
    img_path = self.images_dir / img_info['file_name']
    # Reset T press counter when switching to a new image
    self.t_press_count = 0
    if not img_path.exists():
        logger.warning(f"Image not found: {img_path}")
        return False
    self.current_image = cv2.imread(str(img_path))
    if self.current_image is None:
        logger.warning(f"Failed to load image: {img_path}")
        return False
    self.original_bboxes = []
    self.predicted_bboxes = []
    existing_pseudolabels = []
    # Check if this image has been processed before
    # Use image_id_to_anns which should be kept in sync with saved data
    current_anns = self.image_id_to_anns.get(img_info['id'], [])
    # Log for debugging
    pseudo_count = sum(1 for ann in current_anns if ann.get('is_pseudolabel', False))
    if pseudo_count > 0:
        logger.debug(f"Found {pseudo_count} existing pseudolabels for image {img_info['id']}")
    # Split stored annotations: flagged pseudolabels vs. original boxes.
    for ann in current_anns:
        is_pseudo = ann.get('is_pseudolabel', False)
        if is_pseudo:
            # Score comes from 'confidence', then 'score', else defaults to 0.5.
            bbox = BBox.from_coco(
                ann['bbox'],
                score=ann.get('confidence', ann.get('score', 0.5)),
                source="predicted",
                id=ann['id'],
                category_id=ann.get('category_id', 0),
                area=ann.get('area')
            )
            existing_pseudolabels.append(bbox)
        else:
            bbox = BBox.from_coco(
                ann['bbox'],
                score=1.0,
                source="original",
                id=ann['id'],
                category_id=ann.get('category_id', 0),
                area=ann.get('area')
            )
            self.original_bboxes.append(bbox)
    if existing_pseudolabels and not self.force_repredict:
        # Reuse what was saved earlier instead of running the detector.
        logger.info(f"Image already has {len(existing_pseudolabels)} pseudolabels")
        self.predicted_bboxes = existing_pseudolabels
        self.using_existing_pseudolabels = True
    elif auto_predict:
        if existing_pseudolabels and self.force_repredict:
            logger.info(f"Force re-predicting (ignoring {len(existing_pseudolabels)} existing pseudolabels)")
        else:
            logger.info(f"Generating predictions for {img_info['file_name']}...")
        # In this branch the detector only runs in refine mode with original
        # boxes present; otherwise predicted_bboxes stays empty and
        # using_existing_pseudolabels keeps its previous value.
        if self.refine_mode and self.original_bboxes:
            predicted_bboxes = self.detector.predict(self.current_image, self.refine_score_threshold)
            self.using_existing_pseudolabels = False
            logger.info(f"Applying refinement mode: {len(predicted_bboxes)} predictions -> filtering...")
            refined_results = self.refine_predictions(predicted_bboxes, self.original_bboxes, hw=self.current_image.shape[:2])
            # Matched predictions replace their GT box; unmatched GT boxes stay original.
            self.predicted_bboxes += [bbox for bbox in refined_results if bbox.source == "predicted"]
            self.original_bboxes = [bbox for bbox in refined_results if bbox.source == "original"]
        elif self.refine_mode and not self.original_bboxes:
            logger.info(f"Skipping refinement mode: No GT bboxes available for refinement")
    else:
        logger.info(f"Skipping prediction for {img_info['file_name']} (manual mode)")
        self.using_existing_pseudolabels = bool(existing_pseudolabels)
    similarity = self.calculate_similarity_score(self.predicted_bboxes, self.original_bboxes)
    self.similarity_scores[img_info['id']] = similarity
    # Displayed/editable boxes: originals first, then predictions.
    self.current_bboxes = self.original_bboxes + self.predicted_bboxes
    logger.info(f"Image {idx + 1}/{len(self.all_images)}: "
                f"{len(self.original_bboxes)} original, "
                f"{len(self.predicted_bboxes)} predicted, "
                f"similarity: {similarity:.3f}")
    return True
def draw_bboxes(self, img):
    """Draw bounding boxes on image

    Works on a copy of ``img``. The smallest box under the mouse cursor is
    stored in ``self.hovered_bbox`` and highlighted; the info panel is then
    appended via ``draw_info_panel`` and the combined image is returned.

    NOTE(review): the nesting below was reconstructed from a flattened
    listing -- confirm which statements sit inside the ``is_hovered`` branches.
    """
    vis_img = img.copy()
    # Collect every box containing the cursor, including boxes whose
    # category is currently toggled off.
    hovered_candidates = []
    for bbox in self.current_bboxes:
        x1, y1, x2, y2 = int(bbox.x1), int(bbox.y1), int(bbox.x2), int(bbox.y2)
        if x1 <= self.mouse_x <= x2 and y1 <= self.mouse_y <= y2:
            area = (x2 - x1) * (y2 - y1)
            hovered_candidates.append((area, bbox))
    self.hovered_bbox = None
    if hovered_candidates:
        # Smallest area wins so nested boxes stay selectable.
        hovered_candidates.sort(key=lambda x: x[0])
        self.hovered_bbox = hovered_candidates[0][1]
    for bbox in self.current_bboxes:
        # Respect the visibility toggles.
        if bbox.source == "original" and not self.show_original:
            continue
        if bbox.source == "predicted" and not self.show_predicted:
            continue
        x1, y1, x2, y2 = int(bbox.x1), int(bbox.y1), int(bbox.x2), int(bbox.y2)
        # Equality comparison, not identity.
        is_hovered = (bbox == self.hovered_bbox)
        if bbox.source == "original":
            color = (0, 255, 0)  # Green for original
            label = f"GT"
        else:
            # Color based on score
            if bbox.score > 0.7:
                color = (255, 0, 0)  # Blue for high confidence
            elif bbox.score > 0.5:
                color = (0, 165, 255)  # Orange for medium
            else:
                color = (0, 0, 255)  # Red for low
            label = f"{bbox.score:.2f}"
        thickness = 3 if is_hovered else 2
        if is_hovered:
            color = (0, 255, 255)  # Yellow for hovered
            # Black outline one pixel outside the box.
            cv2.rectangle(vis_img, (x1 - 1, y1 - 1), (x2 + 1, y2 + 1), (0, 0, 0), thickness + 1)
        cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, thickness)
        if is_hovered:
            area = (x2 - x1) * (y2 - y1)
            label = f"{label} [REMOVE] A:{area}"
        # Filled label background above the top-left corner, then the text.
        label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(vis_img, (x1, y1 - label_size[1] - 4),
                      (x1 + label_size[0], y1), color, -1)
        cv2.putText(vis_img, label, (x1, y1 - 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    if self.hovered_bbox:
        # Cross marker at the cursor while something is hovered.
        cv2.drawMarker(vis_img, (self.mouse_x, self.mouse_y),
                       (0, 255, 255), cv2.MARKER_CROSS, 15, 2)
    vis_img = self.draw_info_panel(vis_img)
    return vis_img
def draw_info_panel(self, img):
    """Return ``img`` with a 140-pixel status/help panel stacked underneath.

    The panel shows the current file, box counts, similarity, threshold
    state and the keyboard help. While an auto mode is active a blinking
    border and a large mode label are added.
    """
    w = img.shape[1]
    panel_height = 140
    panel = np.zeros((panel_height, w, 3), dtype=np.uint8)
    panel[:] = (40, 40, 40)

    current_img = self.all_images[self.current_idx]
    dataset_name = current_img.get('dataset', 'unknown')

    # Info text
    pseudo_status = "LOADED" if self.using_existing_pseudolabels else (
        "REFINED" if self.refine_mode else "PREDICTED")
    mode_indicator = " [REFINE MODE]" if self.refine_mode else ""
    if self.auto_mode:
        mode_indicator += " [AUTO MODE ACTIVE]"
    if self.auto_predict_mode:
        mode_indicator += " [AUTO PREDICT MODE ACTIVE]"

    # Threshold that corresponds to the current number of T presses.
    threshold_index = min(self.t_press_count, len(self.threshold_steps) - 1)
    current_threshold = self.threshold_steps[threshold_index]

    info_lines = [
        f"Dataset: {dataset_name} | File: {current_img.get('file_name', 'unknown')}{mode_indicator}",
        f"Image {self.current_idx + 1}/{len(self.all_images)} | "
        f"Original: {len(self.original_bboxes)} | "
        f"Pseudo [{pseudo_status}]: {len(self.predicted_bboxes)} | "
        f"Current: {len(self.current_bboxes)}",
        f"Similarity: {self.similarity_scores.get(current_img['id'], 0):.3f} | "
        f"Processed: {len(self.processed_images)} | "
        f"T presses: {self.t_press_count} | Threshold: {current_threshold:.3f}",
        "Controls: Click=Remove | O=Toggle Original | P=Toggle Predicted | T=Trigger Predict | "
        "R=Remove All Predicted | G=Remove All Original | N=No GT & Next | M=Predict,No GT & Next",
        "Navigation: A/D=Prev/Next (no predict) | Space=Next (auto-predict) | W=Write to Disk | Q=Quit | J=Jump",
        "Auto Modes: Z=Toggle Auto Mode (use existing) | X=Toggle Auto Predict Mode (force predict)"
    ]

    # Spread the baselines so the last line still fits inside the panel:
    # a fixed 25 px step would put the sixth baseline at y=145, below the
    # 140 px panel, and clip the "Auto Modes" help line.
    first_baseline = 20
    bottom_margin = 10
    usable = panel_height - first_baseline - bottom_margin
    line_step = min(25, usable // max(len(info_lines) - 1, 1))
    for row, line in enumerate(info_lines):
        cv2.putText(panel, line, (10, first_baseline + row * line_step),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    # Add auto mode indicator
    if self.auto_mode or self.auto_predict_mode:
        # Alternates twice per second to make the border blink.
        blink_on = (time.time() * 2) % 2 < 1
        if self.auto_predict_mode:
            border_color = (255, 0, 255) if blink_on else (255, 100, 255)  # Magenta for predict mode
            auto_text = "AUTO PREDICT MODE"
        else:
            border_color = (0, 0, 255) if blink_on else (0, 100, 255)  # Red for normal auto
            auto_text = "AUTO MODE ACTIVE"
        cv2.rectangle(panel, (0, 0), (w - 1, panel_height - 1), border_color, 3)

        # Large mode text, right-aligned in the panel.
        font_scale = 1.0
        thickness = 2
        text_size, _ = cv2.getTextSize(auto_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
        text_x = w - text_size[0] - 20
        text_y = 30
        cv2.putText(panel, auto_text, (text_x, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, border_color, thickness)

    return np.vstack([img, panel])
def mouse_callback(self, event, x, y, flags, param):
    """Track the cursor and, on left click, delete the smallest box under it.

    The cursor position is always recorded (the hover highlight is computed
    from it when the frame is drawn). A left click removes the
    smallest-area box containing the click point from ``current_bboxes``
    and from its source list, then saves the image's annotations to memory.
    """
    self.mouse_x, self.mouse_y = x, y

    if event != cv2.EVENT_LBUTTONDOWN:
        return

    # Every box that contains the click point, with its area and list index.
    hits = []
    for position, bbox in enumerate(self.current_bboxes):
        left, top = int(bbox.x1), int(bbox.y1)
        right, bottom = int(bbox.x2), int(bbox.y2)
        if left <= x <= right and top <= y <= bottom:
            hits.append(((right - left) * (bottom - top), position, bbox))
    if not hits:
        return

    # Smallest area wins; on ties the earliest box in the list is taken.
    area, idx, _ = min(hits, key=lambda hit: hit[0])
    removed = self.current_bboxes.pop(idx)
    logger.info(f"Removed {removed.source} bbox with score {removed.score:.3f} (area: {area})")

    # Keep the per-source list in step with current_bboxes.
    if removed.source == "original":
        if removed in self.original_bboxes:
            self.original_bboxes.remove(removed)
    elif removed.source == "predicted":
        if removed in self.predicted_bboxes:
            self.predicted_bboxes.remove(removed)

    # Auto-save to memory after modification
    self.save_current_annotations()
def save_current_annotations(self):
    """Replace the in-memory annotations of the current image with ``current_bboxes``.

    The image is marked as processed and as visited in this session. Its
    old entries are dropped from ``working_data``, ``image_id_to_anns`` and
    ``all_annotations``, and one annotation per current box is written to
    all three. Boxes without an id receive the next free annotation id,
    which is also stored back on the box. For original boxes, fields of the
    matching annotation in ``coco_data`` that are not regenerated here are
    carried over.
    """
    img_info = self.all_images[self.current_idx]
    image_id = img_info['id']
    self.processed_images.add(image_id)
    self.session_visited_images.add(image_id)  # visited in the current session

    # Drop whatever was stored for this image before re-adding the current boxes.
    self.working_data['annotations'] = [
        ann for ann in self.working_data['annotations']
        if ann['image_id'] != image_id
    ]
    self.image_id_to_anns[image_id] = []
    self.all_annotations = [
        ann for ann in self.all_annotations
        if ann['image_id'] != image_id
    ]

    # id -> first annotation with that id in coco_data; built on first use so
    # each box costs one lookup instead of a scan over all annotations.
    originals_by_id = None

    for bbox in self.current_bboxes:
        # ``is not None`` rather than truthiness: 0 is a valid annotation id
        # and must not be replaced by a freshly allocated one.
        has_id = bbox.id is not None
        is_pseudo = bbox.source == 'predicted'
        ann = {
            'id': int(bbox.id) if has_id else self.annotation_id_counter,
            'image_id': int(image_id),
            'category_id': int(bbox.category_id),
            'bbox': [float(v) for v in bbox.to_coco()],
            # A missing (or zero) stored area is recomputed from the corners.
            'area': float(bbox.area) if bbox.area else float((bbox.x2 - bbox.x1) * (bbox.y2 - bbox.y1)),
            'segmentation': [],
            'iscrowd': 0,
            'is_pseudolabel': is_pseudo,  # Mark pseudolabels
            'confidence': float(bbox.score) if is_pseudo else 1.0,  # Store confidence
            'verified': True
        }

        if bbox.source == 'original' and has_id:
            if originals_by_id is None:
                originals_by_id = {}
                for orig_ann in self.coco_data['annotations']:
                    originals_by_id.setdefault(orig_ann.get('id'), orig_ann)
            orig_ann = originals_by_id.get(bbox.id)
            if orig_ann is not None:
                # Preserve fields of the original annotation not regenerated above.
                for key, value in orig_ann.items():
                    if key not in ann:
                        ann[key] = value

        # Update all data structures to ensure persistence
        self.working_data['annotations'].append(ann)
        self.image_id_to_anns[image_id].append(ann)
        self.all_annotations.append(ann)

        if not has_id:
            bbox.id = self.annotation_id_counter  # keep the box and its annotation in sync
            self.annotation_id_counter += 1

    n_original = sum(1 for b in self.current_bboxes if b.source == 'original')
    n_pseudo = sum(1 for b in self.current_bboxes if b.source == 'predicted')
    logger.info(f"Saved {len(self.current_bboxes)} annotations for image {img_info['id']} "
                f"(original: {n_original}, pseudo: {n_pseudo})")
    # Log detailed info for debugging
    logger.debug(f" working_data has {len(self.working_data['annotations'])} total annotations")
    logger.debug(f" image_id_to_anns[{img_info['id']}] has {len(self.image_id_to_anns[img_info['id']])} annotations")
def write_to_disk(self):
    """Write the merged annotation set back to the annotations file.

    Annotations belonging to images visited in *this* session are taken
    from ``self.working_data``; annotations of every other image are taken
    unchanged from ``self.original_coco_annotations`` so that results of
    earlier sessions are preserved.

    The JSON is first written to a temporary file next to the target and
    then moved over it with ``os.replace``. If serialization fails (or the
    process is interrupted) the existing annotations file is left untouched
    instead of being truncated.

    After a successful write the in-memory state (``coco_data``,
    ``all_annotations``, ``original_coco_annotations`` and
    ``image_id_to_anns``) is rebuilt from the saved annotations and
    ``save_progress()`` is called.

    Raises:
        OSError: if the temporary file cannot be written or moved.
        TypeError / ValueError: if the data cannot be serialized to JSON.
    """
    # Shallow copy: top-level keys can be replaced without touching
    # self.coco_data, but nested objects are still shared.
    final_data = self.coco_data.copy()

    # Only images visited in THIS session are replaced; everything else is
    # carried over from the baseline captured at load / last write.
    visited_img_ids = self.session_visited_images
    unvisited_annotations = [
        ann for ann in self.original_coco_annotations
        if ann['image_id'] not in visited_img_ids
    ]
    visited_annotations = [
        ann for ann in self.working_data['annotations']
        if ann['image_id'] in visited_img_ids
    ]
    final_data['annotations'] = unvisited_annotations + visited_annotations

    # Overall statistics
    total_anns = len(final_data['annotations'])
    pseudo_anns = sum(1 for ann in final_data['annotations']
                      if ann.get('is_pseudolabel', False))
    original_anns = total_anns - pseudo_anns

    # Preserved vs. updated statistics
    preserved_anns = len(unvisited_annotations)
    preserved_pseudo = sum(1 for ann in unvisited_annotations
                           if ann.get('is_pseudolabel', False))
    updated_anns = len(visited_annotations)
    updated_pseudo = sum(1 for ann in visited_annotations
                         if ann.get('is_pseudolabel', False))

    logger.debug("Write to disk debug:")
    logger.debug(f" Original annotations count: {len(self.original_coco_annotations)}")
    logger.debug(f" Session visited image IDs: {len(visited_img_ids)} images")
    logger.debug(f" Total processed images (all sessions): {len(self.processed_images)} images")
    logger.debug(f" Unvisited annotations to preserve: {preserved_anns}")
    logger.debug(f" Visited annotations to update: {updated_anns}")

    # Record pseudolabeling metadata under info.pseudolabeling
    if 'info' not in final_data:
        final_data['info'] = {}
    final_data['info']['pseudolabeling'] = {
        'last_updated': time.strftime('%Y-%m-%d %H:%M:%S'),
        'total_annotations': total_anns,
        'original_annotations': original_anns,
        'pseudolabeled_annotations': pseudo_anns,
        'images_processed_total': len(self.processed_images),
        'images_processed_this_session': len(self.session_visited_images),
        'models_used': [os.path.basename(m) for m in self.model_paths],
        'refine_mode': self.refine_mode,
        'iou_threshold': self.iou_threshold if self.refine_mode else None
    }

    # Write atomically: dump to a sibling temp file, then swap it in.
    # realpath() makes sure a symlinked annotations file keeps its link and
    # that the temp file lives on the same filesystem as the real target.
    target_path = os.path.realpath(self.annotations_json)
    tmp_path = target_path + ".tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(final_data, f, indent=2, cls=NumpyEncoder)
        os.replace(tmp_path, target_path)
    finally:
        # After a successful replace the temp file no longer exists; this
        # only cleans up a partial file left behind by a failed dump.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    # Bring the in-memory state in line with what is now on disk. The saved
    # annotations become the new baseline for subsequent writes.
    self.coco_data['annotations'] = final_data['annotations'].copy()
    self.all_annotations = final_data['annotations'].copy()
    self.original_coco_annotations = final_data['annotations'].copy()

    # Rebuild the per-image index from the saved annotations
    self.image_id_to_anns.clear()
    for ann in self.all_annotations:
        self.image_id_to_anns[ann['image_id']].append(ann)

    logger.info(f"Updated annotations saved to {self.annotations_json}")
    logger.info(f" Total: {total_anns} annotations")
    logger.info(f" Original: {original_anns} annotations")
    logger.info(f" Pseudolabeled: {pseudo_anns} annotations")
    logger.info(f" Images visited in this session: {len(self.session_visited_images)}")
    logger.info(f" Total images processed (all sessions): {len(self.processed_images)}")
    logger.info(f" Preserved annotations (unvisited in this session): {preserved_anns} ({preserved_pseudo} pseudolabels)")
    logger.info(f" Updated annotations (visited in this session): {updated_anns} ({updated_pseudo} pseudolabels)")

    # Save progress
    self.save_progress()
def trigger_prediction(self):
    """Run the detector on the current image and store the result.

    The score threshold is taken from ``self.threshold_steps``; every call
    advances one step (``self.t_press_count``) and stays on the last entry
    once the list is exhausted. In refine mode with GT boxes present, the
    predictions are passed through ``refine_predictions`` and both the
    predicted and the original box lists are replaced by its output.
    Finally the similarity score for the image is recorded, the combined
    box list is rebuilt and the annotations are saved.
    """
    image_meta = self.all_images[self.current_idx]

    # Pick the threshold for this press, clamped to the last configured step.
    last_step = len(self.threshold_steps) - 1
    step = min(self.t_press_count, last_step)
    score_threshold = self.threshold_steps[step]

    logger.info(f"Manually triggering prediction for {image_meta['file_name']}...")
    logger.info(f"T press #{self.t_press_count + 1}, using threshold: {score_threshold:.3f}")

    # Fresh predictions replace whatever was loaded for this image.
    self.predicted_bboxes = self.detector.predict(self.current_image, score_threshold)
    self.using_existing_pseudolabels = False

    # The counter only advances once a prediction has actually been made.
    self.t_press_count += 1

    if self.refine_mode:
        if self.original_bboxes:
            logger.info(f"Applying refinement mode: {len(self.predicted_bboxes)} predictions -> filtering...")
            refined = self.refine_predictions(
                self.predicted_bboxes,
                self.original_bboxes,
                hw=self.current_image.shape[:2],
            )
            # Split the refinement output back into its two sources: the
            # kept predictions, and the GT boxes that are still "original".
            self.predicted_bboxes = [box for box in refined if box.source == "predicted"]
            self.original_bboxes = [box for box in refined if box.source == "original"]
        else:
            logger.info("Skipping refinement: No GT bboxes available")

    # Record how well predictions agree with the (remaining) GT boxes.
    self.similarity_scores[image_meta['id']] = self.calculate_similarity_score(
        self.predicted_bboxes, self.original_bboxes
    )

    # GT boxes first, predictions after.
    self.current_bboxes = self.original_bboxes + self.predicted_bboxes
    logger.info(f"Prediction complete: {len(self.predicted_bboxes)} predictions generated")

    # Auto-save after triggering prediction
    self.save_current_annotations()
def jump_to_dataset(self):
    """Prompt on the console for a dataset and jump to its first image.

    The distinct ``dataset`` values found in ``self.all_images`` are listed
    with their image counts. The user may answer with the 1-based number
    from the list or with a dataset name (matched case-insensitively). The
    first image of the chosen dataset is loaded without auto-prediction.
    A ``ValueError``, ``EOFError`` or ``KeyboardInterrupt`` raised while
    prompting or loading cancels the jump.
    """
    names = sorted({entry['dataset'] for entry in self.all_images if 'dataset' in entry})
    if not names:
        logger.warning("No dataset information found in images")
        return

    # Menu of available datasets, numbered from 1
    print("\nAvailable datasets:")
    for number, name in enumerate(names, start=1):
        n_images = sum(1 for entry in self.all_images if entry.get('dataset') == name)
        print(f"{number}. {name} ({n_images} images)")

    try:
        choice = input("Enter dataset number (or name): ").strip()

        if not choice.isdigit():
            # Anything that is not purely digits is treated as a name.
            target_dataset = choice
        else:
            position = int(choice) - 1
            if not 0 <= position < len(names):
                logger.warning("Invalid dataset number")
                return
            target_dataset = names[position]

        # Index of the first image whose dataset matches, ignoring case
        wanted = target_dataset.lower()
        first_hit = next(
            (pos for pos, entry in enumerate(self.all_images)
             if entry.get('dataset', '').lower() == wanted),
            None,
        )
        if first_hit is None:
            logger.warning(f"Dataset not found: {target_dataset}")
        else:
            self.load_image(first_hit, auto_predict=False)
            logger.info(f"Jumped to dataset: {target_dataset}")
    except (ValueError, EOFError, KeyboardInterrupt):
        logger.info("Jump cancelled")
def run(self):
    """Run the interactive OpenCV session until the user quits.

    Creates the window, registers the mouse callback, loads the image at
    ``self.current_idx`` and then loops: draw the current boxes, wait for a
    key, and either advance automatically (auto / auto-predict mode, when
    no key was pressed) or dispatch the pressed key.

    Key bindings handled here:
        q      quit the loop
        d / a  next / previous image (no auto-predict)
        s      save current annotations to memory
        w      save, write to disk, validate (reload from disk on failure)
        o / p  toggle display of original / predicted boxes
        r / g  remove all predicted / original boxes and save
        j      jump to a dataset
        t      trigger prediction
        space  save and load the next image with auto-predict
        f      (refine mode only) re-run refinement on the current image
        n      save, write to disk, load the next image with auto-predict
        m      predict, drop GT boxes, save, write, load the next image
        z / x  toggle auto mode / auto-predict mode

    Any other key pressed while an auto mode is active switches that mode
    off and writes to disk before the key itself is handled. After the
    loop ends the windows are destroyed and, if any image was processed,
    the user is asked on the console whether to write everything to disk.
    """
    cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
    cv2.setMouseCallback(self.window_name, self.mouse_callback)

    # Load first image
    if not self.all_images:
        logger.error("No images to process!")
        return
    if not self.load_image(self.current_idx, auto_predict=False):
        logger.error("Failed to load first image")
        return

    while True:
        # Draw current state
        vis_img = self.draw_bboxes(self.current_image)
        cv2.imshow(self.window_name, vis_img)

        # In the auto modes the configured delay paces the slideshow.
        # cv2.waitKey treats a delay <= 0 as "wait forever", which would
        # stall auto processing, so the delay is clamped to at least 1 ms.
        if self.auto_mode or self.auto_predict_mode:
            wait_time = max(1, self.auto_mode_delay)
        else:
            wait_time = 1
        key = cv2.waitKey(wait_time) & 0xFF  # 255 means no key was pressed

        # Auto predict mode processing (force prediction)
        if self.auto_predict_mode and key == 255:
            logger.info(f"[AUTO PREDICT MODE] Processing image {self.current_idx + 1}/{len(self.all_images)}")

            # Force predict bboxes
            self.predicted_bboxes = self.detector.predict(self.current_image, self.predict_score_threshold)
            self.using_existing_pseudolabels = False

            # Apply refinement if in refine mode and GT bboxes exist
            if self.refine_mode and self.original_bboxes:
                logger.info(f"[AUTO PREDICT MODE] Applying refinement with {len(self.original_bboxes)} GT bboxes")
                # hw comes from the image the detector ran on, the same way
                # trigger_prediction passes it, not from the rendered frame.
                refined_results = self.refine_predictions(
                    self.predicted_bboxes, self.original_bboxes,
                    hw=self.current_image.shape[:2])
                self.predicted_bboxes = [bbox for bbox in refined_results if bbox.source == "predicted"]
            elif self.refine_mode and not self.original_bboxes:
                logger.info("[AUTO PREDICT MODE] No GT bboxes - using raw predictions without refinement")

            # Remove all GT bboxes, keep only predictions
            self.current_bboxes = self.predicted_bboxes.copy()
            self.original_bboxes = []

            # Save current annotations
            self.save_current_annotations()

            # Write to disk periodically (every 10 images)
            if (self.current_idx + 1) % 10 == 0:
                logger.info("[AUTO PREDICT MODE] Auto-saving to disk (every 10 images)")
                self.write_to_disk()

            # Move to next image; prediction happens on the next iteration,
            # so the image is loaded without auto-predict.
            if self.current_idx < len(self.all_images) - 1:
                self.load_image(self.current_idx + 1, auto_predict=False)
            else:
                # Reached the end, disable auto mode
                self.auto_predict_mode = False
                logger.info("[AUTO PREDICT MODE] Reached last image, auto mode disabled")
                self.write_to_disk()  # Final save
            continue  # Skip key handling and process the next image

        # Auto mode processing (use existing predictions or refine)
        elif self.auto_mode and key == 255:
            logger.info(f"[AUTO MODE] Processing image {self.current_idx + 1}/{len(self.all_images)}")

            # Save current annotations
            self.save_current_annotations()

            # Write to disk periodically (every 10 images)
            if (self.current_idx + 1) % 10 == 0:
                logger.info("[AUTO MODE] Auto-saving to disk (every 10 images)")
                self.write_to_disk()

            # Move to next image
            if self.current_idx < len(self.all_images) - 1:
                self.load_image(self.current_idx + 1, auto_predict=True)
            else:
                # Reached the end, disable auto mode
                self.auto_mode = False
                logger.info("[AUTO MODE] Reached last image, auto mode disabled")
                self.write_to_disk()  # Final save
            continue  # Skip key handling and process the next image

        # A key other than the mode's own toggle interrupts auto mode.
        if self.auto_mode and key != 255 and key != ord('z'):
            self.auto_mode = False
            logger.info("[AUTO MODE] Interrupted by user input, auto mode disabled")
            self.write_to_disk()  # Save progress

        # Same for auto predict mode and its toggle key.
        if self.auto_predict_mode and key != 255 and key != ord('x'):
            self.auto_predict_mode = False
            logger.info("[AUTO PREDICT MODE] Interrupted by user input, auto predict mode disabled")
            self.write_to_disk()  # Save progress

        if key == ord('q'):
            # Quit. NOTE(review): the two interrupt checks above already
            # clear both auto flags for 'q', so this branch looks
            # unreachable; it is kept as a safeguard.
            if self.auto_mode or self.auto_predict_mode:
                self.auto_mode = False
                self.auto_predict_mode = False
                self.write_to_disk()
            break
        elif key == ord('d'):
            # Next image
            if self.current_idx < len(self.all_images) - 1:
                self.load_image(self.current_idx + 1, auto_predict=False)
        elif key == ord('a'):
            # Previous image
            if self.current_idx > 0:
                self.load_image(self.current_idx - 1, auto_predict=False)
        elif key == ord('s'):
            # Manual save (though auto-save is enabled)
            self.save_current_annotations()
            logger.info("Annotations saved to memory (auto-save is enabled)")
        elif key == ord('w'):
            # Write to disk
            self.save_current_annotations()  # Save current image first
            self.write_to_disk()
            logger.info("Annotations written to disk")
            # Validate after writing
            if not self.validate_annotations():
                logger.warning("Data consistency issues detected after write")
                # Reload from disk to ensure consistency
                self.reload_annotations_from_disk()
        elif key == ord('o'):
            # Toggle original bboxes
            self.show_original = not self.show_original
            logger.info(f"Original bboxes: {'shown' if self.show_original else 'hidden'}")
        elif key == ord('p'):
            # Toggle predicted bboxes
            self.show_predicted = not self.show_predicted
            logger.info(f"Predicted bboxes: {'shown' if self.show_predicted else 'hidden'}")
        elif key == ord('r'):
            # Remove all predicted bboxes
            self.current_bboxes = [b for b in self.current_bboxes if b.source != "predicted"]
            self.predicted_bboxes = []
            logger.info("Removed all predicted bboxes")
            # Auto-save after modification
            self.save_current_annotations()
        elif key == ord('g'):
            # Remove all original bboxes
            self.current_bboxes = [b for b in self.current_bboxes if b.source != "original"]
            self.original_bboxes = []
            logger.info("Removed all original bboxes")
            # Auto-save after modification
            self.save_current_annotations()
        elif key == ord('j'):
            # Jump to dataset
            self.jump_to_dataset()
        elif key == ord('t'):
            # Trigger prediction manually
            self.trigger_prediction()
        elif key == ord(' '):
            # Space - quick save and next with auto-predict
            self.save_current_annotations()
            if self.current_idx < len(self.all_images) - 1:
                self.load_image(self.current_idx + 1, auto_predict=True)
        elif key == ord('f') and self.refine_mode:
            if self.original_bboxes:
                logger.info("Re-running refinement on current image...")
                all_predictions = self.detector.predict(self.current_image, self.refine_score_threshold)
                # hw comes from the image the detector ran on, the same way
                # trigger_prediction passes it, not from the rendered frame.
                refined_results = self.refine_predictions(
                    all_predictions, self.original_bboxes,
                    hw=self.current_image.shape[:2])
                # Newly refined predictions are appended to the existing
                # ones; GT is reduced to the boxes still tagged "original".
                self.predicted_bboxes += [bbox for bbox in refined_results if bbox.source == "predicted"]
                self.original_bboxes = [bbox for bbox in refined_results if bbox.source == "original"]
                self.current_bboxes = self.original_bboxes + self.predicted_bboxes
                logger.info(
                    f"Refinement complete: {len(all_predictions)} predictions -> {len(self.predicted_bboxes)} refined")
                # Auto-save after refinement
                self.save_current_annotations()
            else:
                logger.info("Cannot run refinement: No GT bboxes available")
        elif key == ord('n'):
            # NOTE(review): the message mentions removing GT bboxes, but no
            # removal happens in this handler - confirm whether that is
            # done elsewhere or is missing.
            logger.info("Removing all GT bboxes, saving, and moving to next image...")
            self.save_current_annotations()
            self.write_to_disk()
            if self.current_idx < len(self.all_images) - 1:
                self.load_image(self.current_idx + 1, auto_predict=True)
            else:
                logger.info("Already at last image")
        elif key == ord('m'):
            # M - Force predict, remove GT, and move to next
            logger.info("Predicting bboxes, removing all GT bboxes, and moving to next image...")

            # Dynamic threshold based on T press count (shared with 't')
            threshold_index = min(self.t_press_count, len(self.threshold_steps) - 1)
            current_threshold = self.threshold_steps[threshold_index]

            # Force predict bboxes (not refine) with dynamic threshold
            self.predicted_bboxes = self.detector.predict(self.current_image, current_threshold)
            self.using_existing_pseudolabels = False

            # Increment T press counter after prediction
            self.t_press_count += 1

            # Remove all GT bboxes, keep only predictions
            self.current_bboxes = self.predicted_bboxes.copy()
            self.original_bboxes = []

            # Save and write to disk
            self.save_current_annotations()
            self.write_to_disk()

            # Move to next image
            if self.current_idx < len(self.all_images) - 1:
                self.load_image(self.current_idx + 1, auto_predict=False)
            else:
                logger.info("Already at last image")
        elif key == ord('z'):
            # Toggle auto mode (uses existing predictions or refines)
            self.auto_mode = not self.auto_mode
            if self.auto_mode:
                # The two auto modes are mutually exclusive
                if self.auto_predict_mode:
                    self.auto_predict_mode = False
                logger.info(f"[AUTO MODE] Enabled - will process images automatically (delay: {self.auto_mode_delay}ms)")
                logger.info("[AUTO MODE] Press 'Z' again to stop, or any other key to interrupt")
                # Ensure we have predictions for current image
                if not self.predicted_bboxes:
                    self.trigger_prediction()
            else:
                logger.info("[AUTO MODE] Disabled")
                # Save when exiting auto mode
                self.save_current_annotations()
                self.write_to_disk()
        elif key == ord('x'):
            # Toggle auto predict mode (forces new predictions)
            self.auto_predict_mode = not self.auto_predict_mode
            if self.auto_predict_mode:
                # The two auto modes are mutually exclusive
                if self.auto_mode:
                    self.auto_mode = False
                logger.info(f"[AUTO PREDICT MODE] Enabled - will predict and process images automatically (delay: {self.auto_mode_delay}ms)")
                logger.info("[AUTO PREDICT MODE] Press 'X' again to stop, or any other key to interrupt")
                logger.info("[AUTO PREDICT MODE] This mode forces new predictions for each image")
            else:
                logger.info("[AUTO PREDICT MODE] Disabled")
                # Save when exiting auto mode
                self.save_current_annotations()
                self.write_to_disk()

    cv2.destroyAllWindows()

    # Final save
    if len(self.processed_images) > 0:
        try:
            response = input("\nSave all annotations to disk? (y/n): ")
        except EOFError:
            # No console to answer the prompt: treat it as "no" rather than
            # ending the session with a traceback.
            logger.warning("No console input available; skipping final save prompt")
            response = 'n'
        if response.lower() == 'y':
            self.write_to_disk()
def main():
    """Parse the command line, build the labeler and start the session.

    Model paths that are not absolute are resolved relative to the
    directory containing this script. The ``--auto-delay`` value is applied
    to the labeler's ``auto_mode_delay`` attribute after construction.
    """
    parser = argparse.ArgumentParser(description="Interactive Pseudolabeling Tool")
    parser.add_argument("--images", type=str,
                        default="/mnt/archive/person_drone/wisard_coco",
                        help="Path to images directory")
    parser.add_argument("--annotations", type=str,
                        default="/mnt/archive/person_drone/wisard_coco/annotations.json",
                        help="Path to COCO annotations JSON file (will be updated in-place)")
    parser.add_argument("--models", nargs="+", type=str,
                        default=[
                            "model_deimhgnetV2m_cpu_v0.pt",
                            "mmpedestron_onnx_mix_traced.pt",
                            "mmpedestron_onnx_v2_traced.pt",
                        ],
                        help="Paths to traced models")
    parser.add_argument("--iou-thr", type=float, default=0.8,
                        help="IoU threshold for matching predictions to GT")
    parser.add_argument("--filter-dataset", type=str, default=None,
                        help="Filter images by dataset name (e.g., 'visdrone2019', 'stanford_drone')")
    parser.add_argument("--force-repredict", action="store_true",
                        help="Force re-prediction even if pseudolabels already exist")
    parser.add_argument("--refine", action="store_true",
                        help="Enable GT refinement mode: predict with low confidence and keep only best matches with GT boxes")
    # %(default)s lets argparse print the real default, so the help text
    # cannot drift from the value above (it previously claimed 50ms).
    parser.add_argument("--auto-delay", type=int, default=200,
                        help="Delay in milliseconds between images in auto mode (default: %(default)sms)")
    args = parser.parse_args()

    # Relative model paths are looked up next to this script
    model_paths = []
    for model_path in args.models:
        if not os.path.isabs(model_path):
            model_path = os.path.join(os.path.dirname(__file__), model_path)
        model_paths.append(model_path)

    labeler = InteractivePseudolabeler(
        images_path=args.images,
        annotations_json=args.annotations,
        model_paths=model_paths,
        iou_threshold=args.iou_thr,
        dataset_filter=args.filter_dataset,
        force_repredict=args.force_repredict,
        refine_mode=args.refine
    )

    # Set auto mode delay from command line
    labeler.auto_mode_delay = args.auto_delay

    # Run interactive session
    labeler.run()


if __name__ == "__main__":
    main()