#!/usr/bin/env python3
"""
Interactive Pseudolabeling Tool for COCO Dataset
Features:
- Visual inspection with OpenCV
- Click to remove false positive bboxes
- Compare predictions with ground truth
- Save pseudolabels to COCO format
"""
import argparse
import json
import logging
import os
import shutil
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
import cv2
import numpy as np
import torch
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
        return super().default(obj)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class BBox:
"""Bounding box with metadata"""
x1: float
y1: float
x2: float
y2: float
score: float
category_id: int = 0
source: str = "predicted" # "predicted", "original", "manual"
id: Optional[int] = None
area: Optional[float] = None
    def to_coco(self):
        """Convert to COCO [x, y, width, height] format."""
        return [self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1]
    @classmethod
    def from_coco(cls, bbox_list, score=1.0, source="original", **kwargs):
        """Build a BBox from a COCO [x, y, width, height] list."""
        x, y, w, h = bbox_list
        return cls(x, y, x + w, y + h, score, source=source, **kwargs)
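    # Worked example of the conversion round-trip (hypothetical values):
    #   BBox.from_coco([10, 20, 30, 40]) -> BBox(x1=10, y1=20, x2=40, y2=60, ...)
    #   and .to_coco() on that box returns [10, 20, 30, 40] again.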
def iou(self, other):
x1 = max(self.x1, other.x1)
y1 = max(self.y1, other.y1)
x2 = min(self.x2, other.x2)
y2 = min(self.y2, other.y2)
if x2 < x1 or y2 < y1:
return 0.0
intersection = (x2 - x1) * (y2 - y1)
area1 = (self.x2 - self.x1) * (self.y2 - self.y1)
area2 = (other.x2 - other.x1) * (other.y2 - other.y1)
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
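    # Worked IoU example (hypothetical boxes):
    #   a = BBox(0, 0, 10, 10, 1.0); b = BBox(5, 5, 15, 15, 1.0)
    #   intersection = 5 * 5 = 25, union = 100 + 100 - 25 = 175, a.iou(b) ~= 0.143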
def overflow_area(self, other):
pred_x1, pred_y1, pred_x2, pred_y2 = self.x1, self.y1, self.x2, self.y2
gt_x1, gt_y1, gt_x2, gt_y2 = other.x1, other.y1, other.x2, other.y2
overflow_left = max(0, gt_x1 - pred_x1)
overflow_top = max(0, gt_y1 - pred_y1)
overflow_right = max(0, pred_x2 - gt_x2)
overflow_bottom = max(0, pred_y2 - gt_y2)
pred_width = pred_x2 - pred_x1
pred_height = pred_y2 - pred_y1
overflow_area = (overflow_left * pred_height +
overflow_right * pred_height +
overflow_top * pred_width +
overflow_bottom * pred_width)
gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
return overflow_area / (gt_area + 1e-6)
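    # Worked overflow example (hypothetical boxes): gt = BBox(0, 0, 10, 10, 1.0),
    # pred = BBox(0, 0, 15, 10, 1.0) -> only overflow_right = 5 is non-zero, so
    # overflow_area = 5 * 10 = 50 and pred.overflow_area(gt) ~= 50 / 100 = 0.5.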
    def calc_area(self, hw):
        """Box area; if `hw` (image height, width) is given, the area is normalized to [0, 1]."""
        if hw is not None:
            h, w = hw
            return (self.x2 / w - self.x1 / w) * (self.y2 / h - self.y1 / h)
        else:
            return (self.x2 - self.x1) * (self.y2 - self.y1)
class PedestrianDetector:
def __init__(self,
model_paths,
target_size=(800, 1333),
tta=False,
tile_grid=(1, 1),
nms_thr=0.5):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.target_size = target_size
self.tta = tta
self.tile_grid = tuple(tile_grid)
self.nms_thr = nms_thr
self.models = [
self._load_model(model_path)
for model_path in model_paths
]
self.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
self.std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    def _load_model(self, model_path):
        """Load a TorchScript model; models with 'cpu' in the file name stay on CPU."""
        assert model_path.endswith('.pt') or '_traced' in model_path, \
            f"Expected a traced .pt model, got {model_path}"
        device = "cpu" if "cpu" in model_path else self.device
        m = torch.jit.load(model_path, map_location=device)
        m.eval()
        return m.to(device)
    def _preprocess_image(self, image, model_name: str):
        # The DEIM model in this setup expects an 800x1024 input on CPU;
        # all other models use self.target_size on self.device.
        target_size = self.target_size if model_name != "deim" else (800, 1024)
        device = self.device if model_name != "deim" else "cpu"
h, w = image.shape[:2]
scale = min(target_size[0] / h, target_size[1] / w)
new_h, new_w = int(h * scale), int(w * scale)
resized = cv2.resize(image, (new_w, new_h))
pad_h = target_size[0] - new_h
pad_w = target_size[1] - new_w
padded = cv2.copyMakeBorder(
resized, 0, pad_h, 0, pad_w,
cv2.BORDER_CONSTANT, value=(0, 0, 0)
)
if model_name == "deim":
norm = padded.astype(np.float32) / 255
else:
norm = (padded.astype(np.float32) - self.mean) / self.std
tensor = torch.from_numpy(norm.transpose(2, 0, 1))[None].float().to(device)
return tensor, scale
def _postprocess_detections(self, output, thresh):
bboxes, _ = output
b_np = bboxes[0].cpu().numpy()
scores = b_np[:, 4]
mask = scores >= thresh
if not mask.any():
return np.zeros((0, 5), dtype=np.float32)
valid = b_np[mask]
return valid
def _postprocess_detections_deim(self, output, hw, thresh):
h, w = hw
scores, bboxes = output
b_np = bboxes[0].cpu().numpy()
        s_np = scores.sigmoid()[0].cpu().numpy()
        mask = (s_np >= thresh).squeeze(-1)  # drop the class dim; keeps a 1-D mask even for a single detection
        if not mask.any():
            return np.zeros((0, 5), dtype=np.float32)
        valid = b_np[mask]
cx, cy, box_w, box_h = valid[:, 0], valid[:, 1], valid[:, 2], valid[:, 3]
x1 = cx - box_w / 2
y1 = cy - box_h / 2
x2 = cx + box_w / 2
y2 = cy + box_h / 2
valid_xyxy = np.stack([x1, y1, x2, y2], axis=1) * [w, h, w, h]
return np.concatenate([valid_xyxy, s_np[mask]], axis=1)
def _rescale_bboxes(self, dets, scale):
if dets.shape[0] == 0:
return dets
dets[:, :4] = dets[:, :4] / scale
return dets
@staticmethod
def _nms(dets, iou_thr):
if dets.shape[0] == 0 or iou_thr <= 0:
return dets
x1 = dets[:, 0]
y1 = dets[:, 1]
x2 = dets[:, 2]
y2 = dets[:, 3]
scores = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
iou = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(iou <= iou_thr)[0]
order = order[inds + 1]
return dets[keep]
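    # Minimal NMS sketch (hypothetical values): two heavily overlapping boxes,
    #   dets = np.array([[0, 0, 10, 10, 0.9], [1, 1, 11, 11, 0.6]], dtype=np.float32)
    #   PedestrianDetector._nms(dets, 0.5) keeps only the 0.9-score box
    # (their IoU under the +1 pixel convention is ~0.70, above the 0.5 threshold).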
def predict(self, image, thresh: float, model_num: Optional[int] = None):
if isinstance(image, str):
img = cv2.imread(image)
if img is None:
raise ValueError(f"Could not load image: {image}")
else:
img = image
all_preds = []
if model_num is None:
models = self.models
else:
models = [self.models[model_num]]
for model in models:
            try:
                name = model.model.original_name.lower()
            except Exception:  # traced models may not expose original_name
                name = "other"
tensor, scale = self._preprocess_image(img, name)
with torch.no_grad():
out = model(tensor)
if "deim" in name:
dets = self._postprocess_detections_deim(out, (800, 1024), thresh)
else:
dets = self._postprocess_detections(out, thresh)
if dets.shape[0] > 0:
dets = self._rescale_bboxes(dets, scale)
all_preds.append(dets)
if not all_preds:
return []
merged = np.vstack(all_preds)
if self.nms_thr > 0:
merged = self._nms(merged, self.nms_thr)
bboxes = []
for det in merged:
x1, y1, x2, y2, score = det
bboxes.append(BBox(x1, y1, x2, y2, score, source="predicted"))
return bboxes
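# Usage sketch for the detector (hypothetical image path and threshold; assumes
# one of the traced models listed in main() sits next to the script):
#   detector = PedestrianDetector(model_paths=["mmpedestron_onnx_v2_traced.pt"])
#   boxes = detector.predict("frame.jpg", thresh=0.3)  # -> List[BBox] with source="predicted"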
class InteractivePseudolabeler:
def __init__(self,
images_path: str,
annotations_json: str,
model_paths: List[str],
iou_threshold: float = 0.5,
                 dataset_filter: Optional[str] = None,
force_repredict: bool = False,
refine_mode: bool = False):
self.images_dir = Path(images_path)
self.annotations_json = Path(annotations_json)
self.dataset_filter = dataset_filter
self.force_repredict = force_repredict
self.refine_mode = refine_mode
self.progress_file = self.annotations_json.parent / "pseudolabel_progress.json"
backup_file = self.annotations_json.parent / f"{self.annotations_json.stem}_backup.json"
if not backup_file.exists():
shutil.copy2(self.annotations_json, backup_file)
logger.info(f"Created backup: {backup_file}")
logger.info(f"Loading annotations from {self.annotations_json}")
with open(self.annotations_json, 'r') as f:
self.coco_data = json.load(f)
# Keep a pristine copy of original annotations that we never modify
# This ensures we can always access unvisited annotations
self.original_coco_annotations = self.coco_data['annotations'].copy()
all_images_unfiltered = self.coco_data['images']
if self.dataset_filter:
logger.info(f"Filtering images by dataset: {self.dataset_filter}")
self.all_images = [
img for img in all_images_unfiltered
if img.get('dataset', '').lower() == self.dataset_filter.lower()
]
if not self.all_images:
self.all_images = [
img for img in all_images_unfiltered
if self.dataset_filter.lower() in img.get('dataset', '').lower()
]
logger.info(f"Found {len(self.all_images)} images from {self.dataset_filter}")
else:
self.all_images = all_images_unfiltered
self.all_annotations = self.coco_data['annotations'].copy() # Make a copy to avoid reference issues
self.image_id_to_anns = defaultdict(list)
for ann in self.all_annotations:
self.image_id_to_anns[ann['image_id']].append(ann)
# Log statistics about existing annotations
total_anns = len(self.all_annotations)
pseudo_anns = sum(1 for ann in self.all_annotations if ann.get('is_pseudolabel', False))
logger.info(f"Loaded {total_anns} annotations ({pseudo_anns} pseudolabels)")
# Log distribution of pseudolabels across images for debugging
images_with_pseudo = set()
for ann in self.all_annotations:
if ann.get('is_pseudolabel', False):
images_with_pseudo.add(ann['image_id'])
if images_with_pseudo:
logger.info(f" Pseudolabels found in {len(images_with_pseudo)} images")
self.image_id_to_info = {img['id']: img for img in self.all_images}
logger.info("Initializing detector models...")
self.detector = PedestrianDetector(
model_paths=model_paths,
nms_thr=0.8
)
self.model_paths = model_paths
self.refine_score_threshold = 0.01
self.predict_score_threshold = 0.3
self.current_idx = 0
self.current_image = None
self.current_bboxes = []
self.original_bboxes = []
self.predicted_bboxes = []
self.similarity_scores = {}
self.using_existing_pseudolabels = False
# T press counter and threshold steps
self.t_press_count = 0
self.threshold_steps = [0.3, 0.25, 0.2, 0.15, 0.1, 0.05, 0.01, 0.005, 0.001]
self.processed_images = set() # Images that have been processed (ever)
self.session_visited_images = set() # Images visited in THIS session only
self.working_data = {
"info": self.coco_data.get("info", {}),
"licenses": self.coco_data.get("licenses", []),
"categories": self.coco_data.get("categories", []),
"images": self.coco_data['images'].copy(),
"annotations": [] # Start empty - will only contain annotations for visited images
}
self.annotation_id_counter = max([ann['id'] for ann in self.all_annotations], default=0) + 1
self.load_progress()
self.window_name = "Interactive Pseudolabeler"
self.mouse_x = 0
self.mouse_y = 0
self.show_original = True
self.show_predicted = True
self.iou_threshold = iou_threshold
self.hovered_bbox = None # Track which bbox is under mouse
self.auto_mode = False # Automatic processing mode
self.auto_predict_mode = False # Automatic predict and process mode
self.auto_mode_delay = 500 # Delay in ms between auto-processing images
def load_progress(self):
if self.progress_file.exists():
with open(self.progress_file, 'r') as f:
progress = json.load(f)
self.processed_images = set(progress.get('processed_images', []))
self.current_idx = progress.get('current_idx', 0)
logger.info(f"Loaded progress: {len(self.processed_images)} images processed")
def save_progress(self):
progress = {
'processed_images': list(self.processed_images),
'current_idx': self.current_idx,
'timestamp': time.time()
}
with open(self.progress_file, 'w') as f:
json.dump(progress, f, indent=2, cls=NumpyEncoder)
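    # The progress file is a small JSON sidecar of the shape
    #   {"processed_images": [<image ids>], "current_idx": 0, "timestamp": 1700000000.0}
    # (the values shown here are illustrative).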
def reload_annotations_from_disk(self):
"""Reload annotations from disk to ensure consistency"""
logger.info("Reloading annotations from disk...")
with open(self.annotations_json, 'r') as f:
self.coco_data = json.load(f)
self.all_annotations = self.coco_data['annotations'].copy()
# Update the original annotations reference
self.original_coco_annotations = self.coco_data['annotations'].copy()
# Rebuild image_id_to_anns
self.image_id_to_anns.clear()
for ann in self.all_annotations:
self.image_id_to_anns[ann['image_id']].append(ann)
# Only update working_data annotations for images visited in THIS session
# This preserves the session's work while keeping unvisited images untouched
visited_anns = [
ann for ann in self.all_annotations
if ann['image_id'] in self.session_visited_images
]
self.working_data['annotations'] = visited_anns
logger.info(f"Reloaded {len(self.all_annotations)} annotations from disk")
logger.info(f" Working data contains {len(self.working_data['annotations'])} annotations for {len(self.session_visited_images)} session-visited images")
def validate_annotations(self):
"""Validate that annotations are consistent across data structures"""
issues = []
# Check if all annotations in working_data are in image_id_to_anns
working_ids = {(ann['id'], ann['image_id']) for ann in self.working_data['annotations']}
mapped_ids = set()
for img_id, anns in self.image_id_to_anns.items():
for ann in anns:
mapped_ids.add((ann['id'], ann['image_id']))
missing_in_map = working_ids - mapped_ids
if missing_in_map:
issues.append(f"Annotations in working_data but not in image_id_to_anns: {missing_in_map}")
# Check pseudolabel counts
pseudo_working = sum(1 for ann in self.working_data['annotations'] if ann.get('is_pseudolabel', False))
pseudo_all = sum(1 for ann in self.all_annotations if ann.get('is_pseudolabel', False))
logger.debug(f"Validation: {pseudo_working} pseudolabels in working_data, {pseudo_all} in all_annotations")
if issues:
logger.warning(f"Validation issues found: {issues}")
return len(issues) == 0
    def calculate_similarity_score(self, predicted: List[BBox], original: List[BBox]) -> float:
        """Mean over predictions of (best IoU against any GT box) * prediction confidence."""
        if not predicted or not original:
            return 0.0
scores = []
for pred in predicted:
best_iou = max([pred.iou(orig) for orig in original], default=0.0)
scores.append(best_iou * pred.score)
return np.mean(scores) if scores else 0.0
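    # Worked example (hypothetical): a single prediction with confidence 0.8 whose
    # best IoU against the GT boxes is 0.5 contributes 0.5 * 0.8 = 0.4, so the
    # similarity for that image would be 0.4.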
    def refine_predictions(self, predicted: List[BBox], original: List[BBox], hw) -> List[BBox]:
        """Replace each GT box with its best-matching prediction; `hw` (image height/width) is accepted but currently unused."""
        if not predicted or not original:
            return original.copy() if original else []
refined = []
used_predictions = set()
matched_gt_indices = set()
for gt_idx, gt_bbox in enumerate(original):
best_match = None
            best_score = float("-inf")
for i, pred_bbox in enumerate(predicted):
if i in used_predictions:
continue
iou = gt_bbox.iou(pred_bbox)
if iou >= 0.01:
overflow_area = pred_bbox.overflow_area(gt_bbox)
combined_score = iou + pred_bbox.score - overflow_area
if combined_score > best_score:
best_score = combined_score
best_match = (i, pred_bbox)
if best_match and best_match[1].score >= self.refine_score_threshold:
used_predictions.add(best_match[0])
matched_gt_indices.add(gt_idx)
refined.append(best_match[1])
logger.info(
f"Refined: GT bbox matched with prediction (conf={best_match[1].score:.3f}, iou={gt_bbox.iou(best_match[1]):.3f})")
# Keep unmatched GT bboxes
unmatched_gt = [gt for idx, gt in enumerate(original) if idx not in matched_gt_indices]
logger.info(f"Refinement: {len(original)} GT boxes -> {len(refined)} refined predictions + {len(unmatched_gt)} unmatched GT boxes")
# Return both refined predictions and unmatched GT bboxes
return refined + unmatched_gt
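    # Matching score sketch (hypothetical numbers): a prediction with IoU 0.6,
    # confidence 0.7 and overflow 0.1 against a GT box scores
    # combined_score = 0.6 + 0.7 - 0.1 = 1.2; the prediction with the best
    # combined_score replaces the GT box, provided its confidence is at least
    # refine_score_threshold.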
def load_image(self, idx: int, auto_predict: bool = True):
if idx < 0 or idx >= len(self.all_images):
return False
# Save current annotations before switching images (if we have a current image)
        if 0 <= self.current_idx < len(self.all_images) and self.current_image is not None:
self.save_current_annotations()
self.current_idx = idx
img_info = self.all_images[idx]
img_path = self.images_dir / img_info['file_name']
# Reset T press counter when switching to a new image
self.t_press_count = 0
if not img_path.exists():
logger.warning(f"Image not found: {img_path}")
return False
self.current_image = cv2.imread(str(img_path))
if self.current_image is None:
logger.warning(f"Failed to load image: {img_path}")
return False
self.original_bboxes = []
self.predicted_bboxes = []
existing_pseudolabels = []
# Check if this image has been processed before
# Use image_id_to_anns which should be kept in sync with saved data
current_anns = self.image_id_to_anns.get(img_info['id'], [])
# Log for debugging
pseudo_count = sum(1 for ann in current_anns if ann.get('is_pseudolabel', False))
if pseudo_count > 0:
logger.debug(f"Found {pseudo_count} existing pseudolabels for image {img_info['id']}")
for ann in current_anns:
is_pseudo = ann.get('is_pseudolabel', False)
if is_pseudo:
bbox = BBox.from_coco(
ann['bbox'],
score=ann.get('confidence', ann.get('score', 0.5)),
source="predicted",
id=ann['id'],
category_id=ann.get('category_id', 0),
area=ann.get('area')
)
existing_pseudolabels.append(bbox)
else:
bbox = BBox.from_coco(
ann['bbox'],
score=1.0,
source="original",
id=ann['id'],
category_id=ann.get('category_id', 0),
area=ann.get('area')
)
self.original_bboxes.append(bbox)
if existing_pseudolabels and not self.force_repredict:
logger.info(f"Image already has {len(existing_pseudolabels)} pseudolabels")
self.predicted_bboxes = existing_pseudolabels
self.using_existing_pseudolabels = True
        elif auto_predict:
            if existing_pseudolabels and self.force_repredict:
                logger.info(f"Force re-predicting (ignoring {len(existing_pseudolabels)} existing pseudolabels)")
            else:
                logger.info(f"Generating predictions for {img_info['file_name']}...")
            if self.refine_mode and self.original_bboxes:
                predicted_bboxes = self.detector.predict(self.current_image, self.refine_score_threshold)
                self.using_existing_pseudolabels = False
                logger.info(f"Applying refinement mode: {len(predicted_bboxes)} predictions -> filtering...")
                refined_results = self.refine_predictions(predicted_bboxes, self.original_bboxes, hw=self.current_image.shape[:2])
                self.predicted_bboxes += [bbox for bbox in refined_results if bbox.source == "predicted"]
                self.original_bboxes = [bbox for bbox in refined_results if bbox.source == "original"]
            elif self.refine_mode and not self.original_bboxes:
                logger.info("Skipping refinement mode: no GT bboxes available for refinement")
            else:
                # No refinement requested: run a plain prediction pass at the default threshold
                self.predicted_bboxes = self.detector.predict(self.current_image, self.predict_score_threshold)
                self.using_existing_pseudolabels = False
        else:
            logger.info(f"Skipping prediction for {img_info['file_name']} (manual mode)")
            self.using_existing_pseudolabels = bool(existing_pseudolabels)
similarity = self.calculate_similarity_score(self.predicted_bboxes, self.original_bboxes)
self.similarity_scores[img_info['id']] = similarity
self.current_bboxes = self.original_bboxes + self.predicted_bboxes
logger.info(f"Image {idx + 1}/{len(self.all_images)}: "
f"{len(self.original_bboxes)} original, "
f"{len(self.predicted_bboxes)} predicted, "
f"similarity: {similarity:.3f}")
return True
def draw_bboxes(self, img):
"""Draw bounding boxes on image"""
vis_img = img.copy()
hovered_candidates = []
for bbox in self.current_bboxes:
x1, y1, x2, y2 = int(bbox.x1), int(bbox.y1), int(bbox.x2), int(bbox.y2)
if x1 <= self.mouse_x <= x2 and y1 <= self.mouse_y <= y2:
area = (x2 - x1) * (y2 - y1)
hovered_candidates.append((area, bbox))
self.hovered_bbox = None
if hovered_candidates:
hovered_candidates.sort(key=lambda x: x[0])
self.hovered_bbox = hovered_candidates[0][1]
for bbox in self.current_bboxes:
if bbox.source == "original" and not self.show_original:
continue
if bbox.source == "predicted" and not self.show_predicted:
continue
x1, y1, x2, y2 = int(bbox.x1), int(bbox.y1), int(bbox.x2), int(bbox.y2)
is_hovered = (bbox == self.hovered_bbox)
if bbox.source == "original":
color = (0, 255, 0) # Green for original
label = f"GT"
else:
# Color based on score
if bbox.score > 0.7:
color = (255, 0, 0) # Blue for high confidence
elif bbox.score > 0.5:
color = (0, 165, 255) # Orange for medium
else:
color = (0, 0, 255) # Red for low
label = f"{bbox.score:.2f}"
thickness = 3 if is_hovered else 2
if is_hovered:
color = (0, 255, 255) # Yellow for hovered
cv2.rectangle(vis_img, (x1 - 1, y1 - 1), (x2 + 1, y2 + 1), (0, 0, 0), thickness + 1)
cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, thickness)
if is_hovered:
area = (x2 - x1) * (y2 - y1)
label = f"{label} [REMOVE] A:{area}"
label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
cv2.rectangle(vis_img, (x1, y1 - label_size[1] - 4),
(x1 + label_size[0], y1), color, -1)
cv2.putText(vis_img, label, (x1, y1 - 2),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
if self.hovered_bbox:
cv2.drawMarker(vis_img, (self.mouse_x, self.mouse_y),
(0, 255, 255), cv2.MARKER_CROSS, 15, 2)
vis_img = self.draw_info_panel(vis_img)
return vis_img
def draw_info_panel(self, img):
h, w = img.shape[:2]
panel_height = 140
panel = np.zeros((panel_height, w, 3), dtype=np.uint8)
panel[:] = (40, 40, 40)
current_img = self.all_images[self.current_idx]
dataset_name = current_img.get('dataset', 'unknown')
# Info text
pseudo_status = "LOADED" if self.using_existing_pseudolabels else (
"REFINED" if self.refine_mode else "PREDICTED")
mode_indicator = " [REFINE MODE]" if self.refine_mode else ""
if self.auto_mode:
mode_indicator += " [AUTO MODE ACTIVE]"
if self.auto_predict_mode:
mode_indicator += " [AUTO PREDICT MODE ACTIVE]"
# Calculate current threshold for display
threshold_index = min(self.t_press_count, len(self.threshold_steps) - 1)
current_threshold = self.threshold_steps[threshold_index]
info_lines = [
f"Dataset: {dataset_name} | File: {current_img.get('file_name', 'unknown')}{mode_indicator}",
f"Image {self.current_idx + 1}/{len(self.all_images)} | "
f"Original: {len(self.original_bboxes)} | "
f"Pseudo [{pseudo_status}]: {len(self.predicted_bboxes)} | "
f"Current: {len(self.current_bboxes)}",
f"Similarity: {self.similarity_scores.get(self.all_images[self.current_idx]['id'], 0):.3f} | "
f"Processed: {len(self.processed_images)} | "
f"T presses: {self.t_press_count} | Threshold: {current_threshold:.3f}",
"Controls: Click=Remove | O=Toggle Original | P=Toggle Predicted | T=Trigger Predict | "
"R=Remove All Predicted | G=Remove All Original | N=No GT & Next | M=Predict,No GT & Next",
"Navigation: A/D=Prev/Next (no predict) | Space=Next (auto-predict) | W=Write to Disk | Q=Quit | J=Jump",
"Auto Modes: Z=Toggle Auto Mode (use existing) | X=Toggle Auto Predict Mode (force predict)"
]
y_offset = 20
for line in info_lines:
cv2.putText(panel, line, (10, y_offset),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
y_offset += 25
# Add auto mode indicator
if self.auto_mode or self.auto_predict_mode:
            # Flashing border to indicate auto mode
if self.auto_predict_mode:
border_color = (255, 0, 255) if (time.time() * 2) % 2 < 1 else (255, 100, 255) # Magenta for predict mode
auto_text = "AUTO PREDICT MODE"
else:
border_color = (0, 0, 255) if (time.time() * 2) % 2 < 1 else (0, 100, 255) # Red for normal auto
auto_text = "AUTO MODE ACTIVE"
cv2.rectangle(panel, (0, 0), (w-1, panel_height-1), border_color, 3)
# Add large mode text
font_scale = 1.0
thickness = 2
text_size, _ = cv2.getTextSize(auto_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
text_x = w - text_size[0] - 20
text_y = 30
cv2.putText(panel, auto_text, (text_x, text_y),
cv2.FONT_HERSHEY_SIMPLEX, font_scale, border_color, thickness)
combined = np.vstack([img, panel])
return combined
def mouse_callback(self, event, x, y, flags, param):
"""Handle mouse events"""
self.mouse_x = x
self.mouse_y = y
        # Hover highlighting is recomputed in draw_bboxes on the next redraw,
        # so plain mouse movement needs no handling here.
if event == cv2.EVENT_LBUTTONDOWN:
# Find all bboxes that contain the clicked point
clicked_bboxes = []
for i, bbox in enumerate(self.current_bboxes):
x1, y1, x2, y2 = int(bbox.x1), int(bbox.y1), int(bbox.x2), int(bbox.y2)
if x1 <= x <= x2 and y1 <= y <= y2:
# Calculate area
area = (x2 - x1) * (y2 - y1)
clicked_bboxes.append((area, i, bbox))
if clicked_bboxes:
# Sort by area (smallest first) and remove the smallest one
clicked_bboxes.sort(key=lambda x: x[0])
area, idx, bbox_to_remove = clicked_bboxes[0]
# Remove the bbox with smallest area
removed = self.current_bboxes.pop(idx)
logger.info(f"Removed {removed.source} bbox with score {removed.score:.3f} (area: {area})")
# Also remove from source lists
if removed.source == "original" and removed in self.original_bboxes:
self.original_bboxes.remove(removed)
elif removed.source == "predicted" and removed in self.predicted_bboxes:
self.predicted_bboxes.remove(removed)
# Auto-save to memory after modification
self.save_current_annotations()
def save_current_annotations(self):
img_info = self.all_images[self.current_idx]
self.processed_images.add(img_info['id'])
self.session_visited_images.add(img_info['id']) # Track that we visited this in current session
# Clear existing annotations for this image from working_data only
# working_data should only contain annotations for visited images
self.working_data['annotations'] = [
ann for ann in self.working_data['annotations']
if ann['image_id'] != img_info['id']
]
# Update image_id_to_anns for this specific image
self.image_id_to_anns[img_info['id']] = []
# Update all_annotations - remove old annotations for this image
# This is safe because we're only updating visited images
self.all_annotations = [
ann for ann in self.all_annotations
if ann['image_id'] != img_info['id']
]
for bbox in self.current_bboxes:
bbox_coco = bbox.to_coco()
bbox_coco = [float(x) for x in bbox_coco]
ann = {
                'id': int(bbox.id) if bbox.id is not None else self.annotation_id_counter,
'image_id': int(img_info['id']),
'category_id': int(bbox.category_id),
'bbox': bbox_coco,
                'area': float(bbox.area) if bbox.area is not None else float((bbox.x2 - bbox.x1) * (bbox.y2 - bbox.y1)),
'segmentation': [],
'iscrowd': 0,
'is_pseudolabel': bbox.source == 'predicted', # Mark pseudolabels
'confidence': float(bbox.score) if bbox.source == 'predicted' else 1.0, # Store confidence
'verified': True
}
            if bbox.source == 'original' and bbox.id is not None:
# Find original annotation to preserve all fields
for orig_ann in self.coco_data['annotations']: # Use original coco_data
if orig_ann['id'] == bbox.id:
for key, value in orig_ann.items():
if key not in ann:
ann[key] = value
break
self.working_data['annotations'].append(ann)
# Update all data structures to ensure persistence
self.image_id_to_anns[img_info['id']].append(ann)
self.all_annotations.append(ann)
            if bbox.id is None:
                bbox.id = self.annotation_id_counter  # Assign the new ID back to the bbox for consistency
                self.annotation_id_counter += 1
n_original = sum(1 for b in self.current_bboxes if b.source == 'original')
n_pseudo = sum(1 for b in self.current_bboxes if b.source == 'predicted')
logger.info(f"Saved {len(self.current_bboxes)} annotations for image {img_info['id']} "
f"(original: {n_original}, pseudo: {n_pseudo})")
# Log detailed info for debugging
logger.debug(f" working_data has {len(self.working_data['annotations'])} total annotations")
logger.debug(f" image_id_to_anns[{img_info['id']}] has {len(self.image_id_to_anns[img_info['id']])} annotations")
def write_to_disk(self):
"""Write updated annotations back to the original file"""
# Merge processed annotations with unprocessed ones
final_data = self.coco_data.copy()
# CRITICAL: Only update images visited in THIS session, not all processed images
# This prevents losing annotations for images processed in previous sessions
visited_img_ids = self.session_visited_images
# Keep all annotations for images that weren't visited in THIS session
# Use the ORIGINAL annotations, not the potentially modified coco_data
# This includes existing pseudolabels from previous sessions
unvisited_annotations = [
ann for ann in self.original_coco_annotations
if ann['image_id'] not in visited_img_ids
]
# Get annotations for visited images from working_data
# These are the updated annotations from the current session
visited_annotations = [
ann for ann in self.working_data['annotations']
if ann['image_id'] in visited_img_ids
]
# Combine annotations: unvisited (preserved) + visited (updated)
final_data['annotations'] = unvisited_annotations + visited_annotations
# Calculate statistics
total_anns = len(final_data['annotations'])
pseudo_anns = sum(1 for ann in final_data['annotations']
if ann.get('is_pseudolabel', False))
original_anns = total_anns - pseudo_anns
# Calculate preserved vs updated statistics
preserved_anns = len(unvisited_annotations)
preserved_pseudo = sum(1 for ann in unvisited_annotations if ann.get('is_pseudolabel', False))
updated_anns = len(visited_annotations)
updated_pseudo = sum(1 for ann in visited_annotations if ann.get('is_pseudolabel', False))
# Debug logging
logger.debug(f"Write to disk debug:")
logger.debug(f" Original annotations count: {len(self.original_coco_annotations)}")
logger.debug(f" Session visited image IDs: {len(visited_img_ids)} images")
logger.debug(f" Total processed images (all sessions): {len(self.processed_images)} images")
logger.debug(f" Unvisited annotations to preserve: {preserved_anns}")
logger.debug(f" Visited annotations to update: {updated_anns}")
# Add metadata about pseudolabeling
if 'info' not in final_data:
final_data['info'] = {}
final_data['info']['pseudolabeling'] = {
'last_updated': time.strftime('%Y-%m-%d %H:%M:%S'),
'total_annotations': total_anns,
'original_annotations': original_anns,
'pseudolabeled_annotations': pseudo_anns,
'images_processed_total': len(self.processed_images),
'images_processed_this_session': len(self.session_visited_images),
'models_used': [os.path.basename(m) for m in self.model_paths],
'refine_mode': self.refine_mode,
'iou_threshold': self.iou_threshold if self.refine_mode else None
}
# Save to original file
with open(self.annotations_json, 'w') as f:
json.dump(final_data, f, indent=2, cls=NumpyEncoder)
# CRITICAL FIX: Update internal data structures to reflect saved state
self.coco_data['annotations'] = final_data['annotations'].copy()
self.all_annotations = final_data['annotations'].copy()
# Update the original annotations to reflect the saved state
# This becomes the new baseline for future saves
self.original_coco_annotations = final_data['annotations'].copy()
# Rebuild image_id_to_anns with updated annotations
self.image_id_to_anns.clear()
for ann in self.all_annotations:
self.image_id_to_anns[ann['image_id']].append(ann)
logger.info(f"Updated annotations saved to {self.annotations_json}")
logger.info(f" Total: {total_anns} annotations")
logger.info(f" Original: {original_anns} annotations")
logger.info(f" Pseudolabeled: {pseudo_anns} annotations")
logger.info(f" Images visited in this session: {len(self.session_visited_images)}")
logger.info(f" Total images processed (all sessions): {len(self.processed_images)}")
logger.info(f" Preserved annotations (unvisited in this session): {preserved_anns} ({preserved_pseudo} pseudolabels)")
logger.info(f" Updated annotations (visited in this session): {updated_anns} ({updated_pseudo} pseudolabels)")
# Save progress
self.save_progress()
def trigger_prediction(self):
"""Manually trigger prediction for current image"""
img_info = self.all_images[self.current_idx]
# Calculate dynamic threshold based on T press count
threshold_index = min(self.t_press_count, len(self.threshold_steps) - 1)
current_threshold = self.threshold_steps[threshold_index]
logger.info(f"Manually triggering prediction for {img_info['file_name']}...")
logger.info(f"T press #{self.t_press_count + 1}, using threshold: {current_threshold:.3f}")
# Generate predictions with dynamic threshold
self.predicted_bboxes = self.detector.predict(self.current_image, current_threshold)
self.using_existing_pseudolabels = False
# Increment T press counter after prediction
self.t_press_count += 1
# Apply refinement if in refine mode and GT bboxes exist
if self.refine_mode and self.original_bboxes:
logger.info(f"Applying refinement mode: {len(self.predicted_bboxes)} predictions -> filtering...")
refined_results = self.refine_predictions(self.predicted_bboxes, self.original_bboxes, hw=self.current_image.shape[:2])
# Separate refined predictions from unmatched GT bboxes
self.predicted_bboxes = [bbox for bbox in refined_results if bbox.source == "predicted"]
# Update original_bboxes to only contain unmatched GT bboxes
self.original_bboxes = [bbox for bbox in refined_results if bbox.source == "original"]
elif self.refine_mode and not self.original_bboxes:
logger.info(f"Skipping refinement: No GT bboxes available")
# Recalculate similarity score
similarity = self.calculate_similarity_score(self.predicted_bboxes, self.original_bboxes)
self.similarity_scores[img_info['id']] = similarity
# Update current bboxes
self.current_bboxes = self.original_bboxes + self.predicted_bboxes
logger.info(f"Prediction complete: {len(self.predicted_bboxes)} predictions generated")
# Auto-save after triggering prediction
self.save_current_annotations()
def jump_to_dataset(self):
"""Jump to a specific dataset"""
# Get unique dataset names
datasets = set()
for img in self.all_images:
if 'dataset' in img:
datasets.add(img['dataset'])
if not datasets:
logger.warning("No dataset information found in images")
return
datasets = sorted(list(datasets))
# Show available datasets
print("\nAvailable datasets:")
for i, ds in enumerate(datasets):
count = sum(1 for img in self.all_images if img.get('dataset') == ds)
print(f"{i + 1}. {ds} ({count} images)")
# Get user input
try:
choice = input("Enter dataset number (or name): ").strip()
# Try to parse as number
if choice.isdigit():
idx = int(choice) - 1
if 0 <= idx < len(datasets):
target_dataset = datasets[idx]
else:
logger.warning("Invalid dataset number")
return
else:
# Use as dataset name
target_dataset = choice
# Find first image from this dataset
for i, img in enumerate(self.all_images):
if img.get('dataset', '').lower() == target_dataset.lower():
self.load_image(i, auto_predict=False)
logger.info(f"Jumped to dataset: {target_dataset}")
return
logger.warning(f"Dataset not found: {target_dataset}")
except (ValueError, EOFError, KeyboardInterrupt):
logger.info("Jump cancelled")
def run(self):
"""Main loop"""
cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
cv2.setMouseCallback(self.window_name, self.mouse_callback)
# Load first image
if not self.all_images:
logger.error("No images to process!")
return
if not self.load_image(self.current_idx, auto_predict=False):
logger.error("Failed to load first image")
return
while True:
# Draw current state
vis_img = self.draw_bboxes(self.current_image)
cv2.imshow(self.window_name, vis_img)
# Handle keyboard input
# Use shorter wait time in auto mode for responsiveness
wait_time = self.auto_mode_delay if (self.auto_mode or self.auto_predict_mode) else 1
key = cv2.waitKey(wait_time) & 0xFF
# Auto predict mode processing (force prediction)
if self.auto_predict_mode and key == 255: # No key pressed
logger.info(f"[AUTO PREDICT MODE] Processing image {self.current_idx + 1}/{len(self.all_images)}")
# Force predict bboxes
self.predicted_bboxes = self.detector.predict(self.current_image, self.predict_score_threshold)
self.using_existing_pseudolabels = False
# Apply refinement if in refine mode and GT bboxes exist
if self.refine_mode and self.original_bboxes:
logger.info(f"[AUTO PREDICT MODE] Applying refinement with {len(self.original_bboxes)} GT bboxes")
refined_results = self.refine_predictions(self.predicted_bboxes, self.original_bboxes, hw=vis_img.shape[:2])
self.predicted_bboxes = [bbox for bbox in refined_results if bbox.source == "predicted"]
elif self.refine_mode and not self.original_bboxes:
logger.info(f"[AUTO PREDICT MODE] No GT bboxes - using raw predictions without refinement")
# Remove all GT bboxes, keep only predictions
self.current_bboxes = self.predicted_bboxes.copy()
self.original_bboxes = []
# Save current annotations
self.save_current_annotations()
# Write to disk periodically (every 10 images)
if (self.current_idx + 1) % 10 == 0:
logger.info("[AUTO PREDICT MODE] Auto-saving to disk (every 10 images)")
self.write_to_disk()
# Move to next image
if self.current_idx < len(self.all_images) - 1:
self.load_image(self.current_idx + 1, auto_predict=False) # Load without auto-predict since we'll predict manually
else:
# Reached the end, disable auto mode
self.auto_predict_mode = False
logger.info("[AUTO PREDICT MODE] Reached last image, auto mode disabled")
self.write_to_disk() # Final save
continue # Skip the rest of the loop to process next image
# Auto mode processing (use existing predictions or refine)
elif self.auto_mode and key == 255: # No key pressed
                # Save current annotations (predictions were generated or loaded by load_image) and move on
logger.info(f"[AUTO MODE] Processing image {self.current_idx + 1}/{len(self.all_images)}")
# Save current annotations
self.save_current_annotations()
# Write to disk periodically (every 10 images)
if (self.current_idx + 1) % 10 == 0:
logger.info("[AUTO MODE] Auto-saving to disk (every 10 images)")
self.write_to_disk()
# Move to next image
if self.current_idx < len(self.all_images) - 1:
self.load_image(self.current_idx + 1, auto_predict=True)
else:
# Reached the end, disable auto mode
self.auto_mode = False
logger.info("[AUTO MODE] Reached last image, auto mode disabled")
self.write_to_disk() # Final save
continue # Skip the rest of the loop to process next image
# If any key is pressed during auto mode (except 255 which means no key), handle it
if self.auto_mode and key != 255 and key != ord('z'):
# Disable auto mode if any other key is pressed
self.auto_mode = False
logger.info("[AUTO MODE] Interrupted by user input, auto mode disabled")
self.write_to_disk() # Save progress
# If any key is pressed during auto predict mode (except 255 which means no key), handle it
if self.auto_predict_mode and key != 255 and key != ord('x'):
# Disable auto predict mode if any other key is pressed
self.auto_predict_mode = False
logger.info("[AUTO PREDICT MODE] Interrupted by user input, auto predict mode disabled")
self.write_to_disk() # Save progress
if key == ord('q'):
# Quit
if self.auto_mode or self.auto_predict_mode:
self.auto_mode = False
self.auto_predict_mode = False
self.write_to_disk()
break
elif key == ord('d'):
if self.current_idx < len(self.all_images) - 1:
self.load_image(self.current_idx + 1, auto_predict=False)
elif key == ord('a'):
if self.current_idx > 0:
self.load_image(self.current_idx - 1, auto_predict=False)
elif key == ord('s'):
# Manual save (though auto-save is enabled)
self.save_current_annotations()
logger.info("Annotations saved to memory (auto-save is enabled)")
elif key == ord('w'):
# Write to disk
self.save_current_annotations() # Save current image first
self.write_to_disk()
logger.info("Annotations written to disk")
# Validate after writing
if not self.validate_annotations():
logger.warning("Data consistency issues detected after write")
# Reload from disk to ensure consistency
self.reload_annotations_from_disk()
elif key == ord('o'):
# Toggle original bboxes
self.show_original = not self.show_original
logger.info(f"Original bboxes: {'shown' if self.show_original else 'hidden'}")
elif key == ord('p'):
# Toggle predicted bboxes
self.show_predicted = not self.show_predicted
logger.info(f"Predicted bboxes: {'shown' if self.show_predicted else 'hidden'}")
elif key == ord('r'):
# Remove all predicted bboxes
self.current_bboxes = [b for b in self.current_bboxes if b.source != "predicted"]
self.predicted_bboxes = []
logger.info("Removed all predicted bboxes")
# Auto-save after modification
self.save_current_annotations()
elif key == ord('g'):
# Remove all original bboxes
self.current_bboxes = [b for b in self.current_bboxes if b.source != "original"]
self.original_bboxes = []
logger.info("Removed all original bboxes")
# Auto-save after modification
self.save_current_annotations()
elif key == ord('j'):
# Jump to dataset
self.jump_to_dataset()
elif key == ord('t'):
# Trigger prediction manually
self.trigger_prediction()
elif key == ord(' '):
# Space - quick save and next with auto-predict
self.save_current_annotations()
if self.current_idx < len(self.all_images) - 1:
self.load_image(self.current_idx + 1, auto_predict=True)
elif key == ord('f') and self.refine_mode:
if self.original_bboxes:
logger.info("Re-running refinement on current image...")
all_predictions = self.detector.predict(self.current_image, self.refine_score_threshold)
refined_results = self.refine_predictions(all_predictions, self.original_bboxes,
hw=vis_img.shape[:2])
self.predicted_bboxes += [bbox for bbox in refined_results if bbox.source == "predicted"]
self.original_bboxes = [bbox for bbox in refined_results if bbox.source == "original"]
self.current_bboxes = self.original_bboxes + self.predicted_bboxes
logger.info(
f"Refinement complete: {len(all_predictions)} predictions -> {len(self.predicted_bboxes)} refined")
# Auto-save after refinement
self.save_current_annotations()
else:
logger.info("Cannot run refinement: No GT bboxes available")
            elif key == ord('n'):
                # N - Remove all GT bboxes, save, and move to next
                logger.info("Removing all GT bboxes, saving, and moving to next image...")
                self.current_bboxes = [b for b in self.current_bboxes if b.source != "original"]
                self.original_bboxes = []
                self.save_current_annotations()
                self.write_to_disk()
                if self.current_idx < len(self.all_images) - 1:
                    self.load_image(self.current_idx + 1, auto_predict=True)
                else:
                    logger.info("Already at last image")
elif key == ord('m'):
# M - Force predict, remove GT, and move to next
logger.info("Predicting bboxes, removing all GT bboxes, and moving to next image...")
# Calculate dynamic threshold based on T press count
threshold_index = min(self.t_press_count, len(self.threshold_steps) - 1)
current_threshold = self.threshold_steps[threshold_index]
# Force predict bboxes (not refine) with dynamic threshold
self.predicted_bboxes = self.detector.predict(self.current_image, current_threshold)
self.using_existing_pseudolabels = False
# Increment T press counter after prediction
self.t_press_count += 1
# Remove all GT bboxes, keep only predictions
self.current_bboxes = self.predicted_bboxes.copy()
self.original_bboxes = []
# Save and write to disk
self.save_current_annotations()
self.write_to_disk()
# Move to next image
if self.current_idx < len(self.all_images) - 1:
self.load_image(self.current_idx + 1, auto_predict=False)
else:
logger.info("Already at last image")
elif key == ord('z'):
# Toggle auto mode (uses existing predictions or refines)
self.auto_mode = not self.auto_mode
if self.auto_mode:
# Disable auto predict mode if it's on
if self.auto_predict_mode:
self.auto_predict_mode = False
logger.info(f"[AUTO MODE] Enabled - will process images automatically (delay: {self.auto_mode_delay}ms)")
logger.info("[AUTO MODE] Press 'Z' again to stop, or any other key to interrupt")
# Ensure we have predictions for current image
if not self.predicted_bboxes:
self.trigger_prediction()
else:
logger.info("[AUTO MODE] Disabled")
# Save when exiting auto mode
self.save_current_annotations()
self.write_to_disk()
elif key == ord('x'):
# Toggle auto predict mode (forces new predictions)
self.auto_predict_mode = not self.auto_predict_mode
if self.auto_predict_mode:
# Disable normal auto mode if it's on
if self.auto_mode:
self.auto_mode = False
logger.info(f"[AUTO PREDICT MODE] Enabled - will predict and process images automatically (delay: {self.auto_mode_delay}ms)")
logger.info("[AUTO PREDICT MODE] Press 'X' again to stop, or any other key to interrupt")
logger.info("[AUTO PREDICT MODE] This mode forces new predictions for each image")
else:
logger.info("[AUTO PREDICT MODE] Disabled")
# Save when exiting auto mode
self.save_current_annotations()
self.write_to_disk()
cv2.destroyAllWindows()
# Final save
if len(self.processed_images) > 0:
response = input("\nSave all annotations to disk? (y/n): ")
if response.lower() == 'y':
self.write_to_disk()
def main():
parser = argparse.ArgumentParser(description="Interactive Pseudolabeling Tool")
parser.add_argument("--images", type=str,
default="/mnt/archive/person_drone/wisard_coco",
help="Path to images directory")
parser.add_argument("--annotations", type=str,
default="/mnt/archive/person_drone/wisard_coco/annotations.json",
help="Path to COCO annotations JSON file (will be updated in-place)")
parser.add_argument("--models", nargs="+", type=str,
default=[
"model_deimhgnetV2m_cpu_v0.pt",
"mmpedestron_onnx_mix_traced.pt",
"mmpedestron_onnx_v2_traced.pt",
],
help="Paths to traced models")
parser.add_argument("--iou-thr", type=float, default=0.8,
help="IoU threshold for matching predictions to GT")
parser.add_argument("--filter-dataset", type=str, default=None,
help="Filter images by dataset name (e.g., 'visdrone2019', 'stanford_drone')")
parser.add_argument("--force-repredict", action="store_true",
help="Force re-prediction even if pseudolabels already exist")
parser.add_argument("--refine", action="store_true",
help="Enable GT refinement mode: predict with low confidence and keep only best matches with GT boxes")
parser.add_argument("--auto-delay", type=int, default=200,
help="Delay in milliseconds between images in auto mode (default: 50ms)")
args = parser.parse_args()
model_paths = []
for model_path in args.models:
if not os.path.isabs(model_path):
model_path = os.path.join(os.path.dirname(__file__), model_path)
model_paths.append(model_path)
labeler = InteractivePseudolabeler(
images_path=args.images,
annotations_json=args.annotations,
model_paths=model_paths,
iou_threshold=args.iou_thr,
dataset_filter=args.filter_dataset,
force_repredict=args.force_repredict,
refine_mode=args.refine
)
# Set auto mode delay from command line
labeler.auto_mode_delay = args.auto_delay
# Run interactive session
labeler.run()
if __name__ == "__main__":
main()