| """ |
| Find best conf_threshold and gap_threshold for Jina and Nomic using COCO ground truth. |
| |
Expects full_frames/ with images and _annotations.coco.json (COCO format).
| Runs the same detection + crop pipeline, matches each crop to a GT annotation (IoU), |
| then grid-searches (conf_threshold, gap_threshold) to maximize accuracy. |
| """ |
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from transformers import AutoImageProcessor, DFineForObjectDetection |
|
|
| from dfine_jina_pipeline import ( |
| box_center_inside, |
| box_iou, |
| deduplicate_by_iou, |
| get_person_car_label_ids, |
| group_detections, |
| run_dfine, |
| squarify_crop_box, |
| ) |
| from jina_fewshot import IMAGE_EXTS, TRUNCATE_DIM, JinaCLIPv2Encoder, build_refs, draw_label_on_image |
| from nomic_fewshot import NomicTextEncoder, NomicVisionEncoder, build_refs_nomic |
|
|
| |
# Target classes, in the order the reference embeddings are built
# (main() asserts build_refs returns labels in exactly this order).
CLASS_NAMES = ["cigarette", "gun", "knife", "phone"]
|
|
|
|
def coco_bbox_to_xyxy(bbox):
    """Convert a COCO bbox [x, y, w, h] into corner form [x1, y1, x2, y2].

    Values are coerced with float() so string-typed numbers coming out of
    JSON are accepted too.
    """
    left, top, width, height = map(float, bbox)
    return [left, top, left + width, top + height]
|
|
|
|
| def map_category_to_class(name: str) -> str | None: |
| """Map COCO category name to one of our 4 classes, or None if other.""" |
| n = (name or "").strip().lower() |
| if "cigarette" in n: |
| return "cigarette" |
| if any(x in n for x in ("gun", "pistol", "handgun", "firearm")): |
| return "gun" |
| if "knife" in n or "blade" in n: |
| return "knife" |
| if any(x in n for x in ("phone", "cell", "mobile", "smartphone", "telephone")): |
| return "phone" |
| return None |
|
|
|
|
def load_coco_gt(annotations_path: Path):
    """
    Load COCO JSON. Returns:
    - file_to_gts: dict[basename] = list of (bbox_xyxy, category_name)
    - categories: list of category dicts from COCO

    Fix: annotations whose image_id does not appear in the "images" list are
    skipped instead of raising KeyError (filtered/partial COCO exports often
    contain such dangling annotations). Unknown category ids map to "".
    """
    with open(annotations_path) as f:
        data = json.load(f)
    images = {im["id"]: im for im in data.get("images", [])}
    categories = {c["id"]: c["name"] for c in data.get("categories", [])}
    # Seed every listed image so images with zero annotations still appear.
    file_to_gts = {im["file_name"]: [] for im in images.values()}
    for ann in data.get("annotations", []):
        im = images.get(ann["image_id"])
        if im is None:
            # Dangling annotation (its image was filtered out) -- ignore it.
            continue
        cat_name = categories.get(ann["category_id"], "")
        file_to_gts[im["file_name"]].append((coco_bbox_to_xyxy(ann["bbox"]), cat_name))

    # Key by basename so lookups work whether file_name carries a sub-folder.
    # NOTE(review): duplicate basenames overwrite each other (same as the
    # original behavior) -- confirm file names are unique per dataset.
    by_basename = {Path(fn).name: gts for fn, gts in file_to_gts.items()}
    return by_basename, data.get("categories", [])
|
|
|
|
def assign_gt_to_crop(crop_box_xyxy, gt_list, iou_min=0.3):
    """
    Match this crop against the mappable GT box it overlaps most.

    Returns (gt_class, iou): gt_class is one of CLASS_NAMES, or None when no
    GT with a mappable category reaches iou_min overlap (iou then stays 0.0).
    """
    best = (None, 0.0)
    for gt_box, cat_name in gt_list:
        overlap = box_iou(crop_box_xyxy, gt_box)
        # Must clear the floor AND beat the current best before mapping.
        if overlap < iou_min or overlap <= best[1]:
            continue
        mapped = map_category_to_class(cat_name)
        if mapped is not None:
            best = (mapped, overlap)
    return best
|
|
|
|
def parse_args():
    """Build and parse the CLI for the threshold-tuning run."""
    parser = argparse.ArgumentParser(description="Tune Jina/Nomic thresholds using COCO GT")
    add = parser.add_argument
    # Input / output locations.
    add("--input", default="full_frames", help="Folder with images and annotations.coco.json")
    add("--annotations", default=None, help="Path to annotations.coco.json (default: input/_annotations.coco.json)")
    add("--refs", required=True, help="Reference images folder (for Jina + Nomic refs)")
    add("--output", default="threshold_tuning", help="Output folder for results CSV")
    # Detection / cropping knobs.
    add("--det-threshold", type=float, default=0.3)
    add("--group-dist", type=float, default=None)
    add("--expand", type=float, default=0.3)
    add("--min-side", type=int, default=40)
    add("--text-weight", type=float, default=0.3)
    add("--iou-min", type=float, default=0.3, help="Min IoU to match crop to GT")
    add("--crop-dedup-iou", type=float, default=0.35, help="Min IoU to treat two crops as same object (keep larger)")
    add("--no-squarify", action="store_true", help="Skip squarify; use expanded bbox only (tighter crops, often better recognition)")
    # Run control / debug output.
    add("--max-images", type=int, default=None)
    add("--device", default=None)
    add("--no-save-crops", action="store_true", help="Do not save annotated crop images")
    add("--save-conf", type=float, default=0.5, help="Conf threshold for saved crop labels")
    add("--save-gap", type=float, default=0.02, help="Gap threshold for saved crop labels")
    return parser.parse_args()
|
|
|
|
def main():
    """Run the detect -> group -> crop -> embed pipeline over GT-annotated
    frames, then grid-search (conf_threshold, gap_threshold) for the Jina and
    Nomic classifiers and write a report plus the full grid CSV."""
    args = parse_args()
    # Prefer CUDA when available unless a device was pinned on the CLI.
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    input_dir = Path(args.input)
    refs_dir = Path(args.refs)
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    annotations_path = Path(args.annotations) if args.annotations else input_dir / "_annotations.coco.json"
    if not annotations_path.is_file():
        raise SystemExit(f"Annotations not found: {annotations_path}")

    file_to_gts, _ = load_coco_gt(annotations_path)
    print(f"[*] Loaded GT for {len(file_to_gts)} images from {annotations_path}")

    paths = sorted(p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS)
    if args.max_images is not None:
        paths = paths[: args.max_images]
    # Keep only images that actually have GT entries.
    # NOTE(review): --max-images is applied BEFORE this filter, so it caps raw
    # images, not annotated ones -- confirm that ordering is intended.
    paths = [p for p in paths if p.name in file_to_gts]
    if not paths:
        raise SystemExit("No images in input that have COCO annotations.")

    # Detector: D-FINE (Objects365), used both for person/car grouping and for
    # the non-person/car objects cropped inside each group.
    print(f"[*] Loading D-FINE...")
    image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-medium-obj365")
    dfine_model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-medium-obj365")
    dfine_model = dfine_model.to(device).eval()
    person_car_ids = get_person_car_label_ids(dfine_model)

    print("[*] Loading Jina-CLIP-v2 and building refs...")
    jina_encoder = JinaCLIPv2Encoder(device)
    ref_labels, ref_embs = build_refs(
        jina_encoder, refs_dir, TRUNCATE_DIM, args.text_weight, batch_size=16
    )
    # Indices into ref_embs are interpreted via CLASS_NAMES order below.
    assert ref_labels == CLASS_NAMES, f"Ref order {ref_labels}"

    print("[*] Loading Nomic (vision + text) and building refs (same as Jina: text_weight 0.3)...")
    nomic_encoder = NomicVisionEncoder(device)
    nomic_text_encoder = NomicTextEncoder(device)
    ref_labels_nomic, ref_embs_nomic = build_refs_nomic(
        nomic_encoder, refs_dir, batch_size=16,
        text_encoder=nomic_text_encoder, text_weight=args.text_weight,
    )

    # Optional debug artifacts: labeled crops per model, raw crops, and the
    # person/car group crops.
    save_crops = not args.no_save_crops
    if save_crops:
        jina_crops_dir = output_dir / "jina_crops"
        nomic_crops_dir = output_dir / "nomic_crops"
        crops_no_label_dir = output_dir / "crops"
        detection_crops_dir = output_dir / "detection_crops"
        jina_crops_dir.mkdir(parents=True, exist_ok=True)
        nomic_crops_dir.mkdir(parents=True, exist_ok=True)
        crops_no_label_dir.mkdir(parents=True, exist_ok=True)
        detection_crops_dir.mkdir(parents=True, exist_ok=True)

    # One row per GT-matched crop: best index / confidence / top-2 gap for
    # both encoders. The grid search below reuses these without re-encoding.
    rows = []
    for img_path in paths:
        pil = Image.open(img_path).convert("RGB")
        img_w, img_h = pil.size
        # Default grouping distance: 10% of the image's longest side.
        group_dist = args.group_dist if args.group_dist is not None else 0.1 * max(img_h, img_w)
        detections = run_dfine(pil, image_processor, dfine_model, device, args.det_threshold)
        person_car = [d for d in detections if d["cls"] in person_car_ids]
        if not person_car:
            continue
        grouped = group_detections(person_car, group_dist)
        grouped.sort(key=lambda x: x["conf"], reverse=True)
        gt_list = file_to_gts.get(img_path.name, [])
        if not gt_list:
            continue

        # Gather candidate object crops inside the top-10 highest-confidence
        # person/car groups; each candidate is (expanded_box, gt_class, gidx, crop_idx).
        candidates = []
        for gidx, grp in enumerate(grouped[:10]):
            x1, y1, x2, y2 = grp["box"]
            group_box = [x1, y1, x2, y2]

            if save_crops:
                # Clamp the group box to the image before saving its crop.
                gx1 = max(0, int(x1))
                gy1 = max(0, int(y1))
                gx2 = min(img_w, int(x2))
                gy2 = min(img_h, int(y2))
                if gx2 > gx1 and gy2 > gy1:
                    group_crop = pil.crop((gx1, gy1, gx2, gy2))
                    group_crop.save(detection_crops_dir / f"{img_path.stem}_group{gidx}.jpg")
            # Non-person/car detections whose centers fall inside this group.
            inside = [
                d for d in detections
                if box_center_inside(d["box"], group_box)
                and d["cls"] not in person_car_ids
            ]
            inside = deduplicate_by_iou(inside, iou_threshold=0.9)

            for crop_idx, d in enumerate(inside):
                bx1, by1, bx2, by2 = [float(x) for x in d["box"]]
                obj_w, obj_h = bx2 - bx1, by2 - by1
                if obj_w <= 0 or obj_h <= 0:
                    continue
                # Expand the box by --expand on each side, clamped to the image.
                pad_x, pad_y = obj_w * args.expand, obj_h * args.expand
                bx1 = max(0, int(bx1 - pad_x))
                by1 = max(0, int(by1 - pad_y))
                bx2 = min(img_w, int(bx2 + pad_x))
                by2 = min(img_h, int(by2 + pad_y))
                if bx2 <= bx1 or by2 <= by1:
                    continue
                # Skip crops too small to recognize anything.
                if min(bx2 - bx1, by2 - by1) < args.min_side:
                    continue
                expanded_box = [bx1, by1, bx2, by2]
                # Only keep crops that match a GT box of one of our 4 classes.
                gt_class, _ = assign_gt_to_crop(expanded_box, gt_list, args.iou_min)
                if gt_class is None:
                    continue
                candidates.append((expanded_box, gt_class, gidx, crop_idx))

        def crop_area(box):
            # Area of an xyxy box (for largest-first dedup ordering).
            return (box[2] - box[0]) * (box[3] - box[1])

        # Deduplicate overlapping candidates, keeping the larger crop first.
        candidates.sort(key=lambda c: -crop_area(c[0]))
        kept = []
        for c in candidates:
            expanded_box = c[0]

            def is_same_object(box_a, box_b):
                # Same object if IoU clears --crop-dedup-iou or either box's
                # center lies inside the other.
                if box_iou(box_a, box_b) >= args.crop_dedup_iou:
                    return True
                if box_center_inside(box_a, box_b) or box_center_inside(box_b, box_a):
                    return True
                return False
            if not any(is_same_object(expanded_box, k[0]) for k in kept):
                kept.append(c)

        # Encode each kept crop with both models and record the stats.
        for i, (expanded_box, gt_class, gidx, crop_idx) in enumerate(kept):
            if not args.no_squarify:
                bx1, by1, bx2, by2 = squarify_crop_box(
                    expanded_box[0], expanded_box[1], expanded_box[2], expanded_box[3], img_w, img_h
                )
            else:
                bx1, by1, bx2, by2 = expanded_box[0], expanded_box[1], expanded_box[2], expanded_box[3]
            # NOTE(review): crop_box is assigned but never used below.
            crop_box = [bx1, by1, bx2, by2]
            crop_pil = pil.crop((bx1, by1, bx2, by2))
            bbox_suffix = f"_{bx1}_{by1}_{bx2}_{by2}"
            crop_name = f"{img_path.stem}_g{gidx}_{i}{bbox_suffix}.jpg"
            # Jina: cosine similarities against the refs; gap = top1 - top2.
            # (np.partition assumes at least 2 reference classes.)
            q_jina = jina_encoder.encode_images([crop_pil], TRUNCATE_DIM)
            sims_jina = (q_jina @ ref_embs.T).squeeze(0)
            best_jina = int(np.argmax(sims_jina))
            conf_jina = float(sims_jina[best_jina])
            gap_jina = float(sims_jina[best_jina] - np.partition(sims_jina, -2)[-2])

            # Nomic: same statistics with its own refs.
            q_nomic = nomic_encoder.encode_images([crop_pil])
            sims_nomic = (q_nomic @ ref_embs_nomic.T).squeeze(0)
            best_nomic = int(np.argmax(sims_nomic))
            conf_nomic = float(sims_nomic[best_nomic])
            gap_nomic = float(sims_nomic[best_nomic] - np.partition(sims_nomic, -2)[-2])

            rows.append({
                "gt": gt_class,
                "jina_best_idx": best_jina,
                "jina_conf": conf_jina,
                "jina_gap": gap_jina,
                "nomic_best_idx": best_nomic,
                "nomic_conf": conf_nomic,
                "nomic_gap": gap_nomic,
            })

            if save_crops:
                crop_pil.save(crops_no_label_dir / crop_name)
                # Saved-crop labels use the fixed --save-conf/--save-gap
                # thresholds (independent of the grid search below).
                sc, sg = args.save_conf, args.save_gap
                label_jina = ref_labels[best_jina] if (conf_jina >= sc and gap_jina >= sg) else f"unknown (gt:{gt_class})"
                label_nomic = ref_labels_nomic[best_nomic] if (conf_nomic >= sc and gap_nomic >= sg) else f"unknown (gt:{gt_class})"
                ann_jina = draw_label_on_image(crop_pil, label_jina, conf_jina)
                ann_nomic = draw_label_on_image(crop_pil, label_nomic, conf_nomic)
                ann_jina.save(jina_crops_dir / crop_name)
                ann_nomic.save(nomic_crops_dir / crop_name)

    if not rows:
        raise SystemExit("No crops matched to GT (with our 4 classes). Check annotations and iou_min.")

    print(f"[*] {len(rows)} crops with GT in {{cigarette, gun, knife, phone}}")
    if save_crops:
        print(f"[*] Annotated crops saved to {jina_crops_dir} and {nomic_crops_dir}")
        print(f"[*] Raw crops (no label) saved to {crops_no_label_dir}")
        print(f"[*] Person/car grouping crops saved to {detection_crops_dir}")

    # Threshold grid to search over (7 x 4 combinations per model).
    conf_candidates = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80]
    gap_candidates = [0.02, 0.05, 0.08, 0.10]

    def accuracy_jina(conf_t, gap_t):
        # Fraction of rows where the thresholded Jina prediction equals GT
        # ("unknown" when either threshold fails, which never matches GT).
        correct = 0
        for r in rows:
            pred = ref_labels[r["jina_best_idx"]] if (r["jina_conf"] >= conf_t and r["jina_gap"] >= gap_t) else "unknown"
            if pred == r["gt"]:
                correct += 1
        return correct / len(rows)

    def accuracy_nomic(conf_t, gap_t):
        # Same accuracy metric for the Nomic predictions.
        correct = 0
        for r in rows:
            pred = ref_labels_nomic[r["nomic_best_idx"]] if (r["nomic_conf"] >= conf_t and r["nomic_gap"] >= gap_t) else "unknown"
            if pred == r["gt"]:
                correct += 1
        return correct / len(rows)

    # Exhaustive grid search, ties resolved by first (lowest) thresholds.
    best_jina_acc = -1
    best_jina_conf = best_jina_gap = None
    for c in conf_candidates:
        for g in gap_candidates:
            acc = accuracy_jina(c, g)
            if acc > best_jina_acc:
                best_jina_acc = acc
                best_jina_conf, best_jina_gap = c, g

    best_nomic_acc = -1
    best_nomic_conf = best_nomic_gap = None
    for c in conf_candidates:
        for g in gap_candidates:
            acc = accuracy_nomic(c, g)
            if acc > best_nomic_acc:
                best_nomic_acc = acc
                best_nomic_conf, best_nomic_gap = c, g

    # Human-readable summary of the winning thresholds.
    report_path = output_dir / "best_thresholds.txt"
    with open(report_path, "w") as f:
        f.write(f"Based on {len(rows)} crops with GT in {{cigarette, gun, knife, phone}}\n")
        if save_crops:
            f.write(f"Annotated crops: jina_crops/ and nomic_crops/ (conf>={args.save_conf}, gap>={args.save_gap})\n")
            f.write("Raw crops (no label): crops/\n")
            f.write("Person/car grouping only: detection_crops/\n")
        f.write("\n")
        f.write("Jina (best accuracy):\n")
        f.write(f" conf_threshold = {best_jina_conf}\n")
        f.write(f" gap_threshold = {best_jina_gap}\n")
        f.write(f" accuracy = {best_jina_acc:.4f}\n\n")
        f.write("Nomic (best accuracy):\n")
        f.write(f" conf_threshold = {best_nomic_conf}\n")
        f.write(f" gap_threshold = {best_nomic_gap}\n")
        f.write(f" accuracy = {best_nomic_acc:.4f}\n")
    print(f"\n[*] Best thresholds written to {report_path}")
    print("\nJina best: conf_threshold={}, gap_threshold={} -> accuracy={:.4f}".format(
        best_jina_conf, best_jina_gap, best_jina_acc))
    print("Nomic best: conf_threshold={}, gap_threshold={} -> accuracy={:.4f}".format(
        best_nomic_conf, best_nomic_gap, best_nomic_acc))

    # Full accuracy grid for offline inspection/plotting.
    import csv
    csv_path = output_dir / "grid_search.csv"
    with open(csv_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["model", "conf_threshold", "gap_threshold", "accuracy"])
        for c in conf_candidates:
            for g in gap_candidates:
                w.writerow(["jina", c, g, f"{accuracy_jina(c, g):.4f}"])
        for c in conf_candidates:
            for g in gap_candidates:
                w.writerow(["nomic", c, g, f"{accuracy_nomic(c, g):.4f}"])
    print(f"[*] Full grid written to {csv_path}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|