import warnings

import cv2
import numpy as np
import rfdetr
import supervision as sv

from base_inference import BaseInference

# Suppress PyTorch meshgrid warnings
warnings.filterwarnings("ignore", category=UserWarning, message="torch.meshgrid")


class RFDETRInference(BaseInference):
    """
    A class to perform inference using RF-DETR models of different sizes.

    Images are accepted in OpenCV layout (BGR, BGRA or single-channel
    grayscale) and converted to the RGB channel order that RF-DETR's
    ``predict`` expects before inference.
    """

    def __init__(self, version='small',
                 pretrain_weights="./models/rfdetr_small/checkpoint_best_total.pth"):
        """
        Initializes the RFDETR model.

        Args:
            version (str): Model version ('nano', 'small', 'medium', 'base', 'base2', 'large').
            pretrain_weights (str): Path to the pretrained .pth weights file.

        Raises:
            ValueError: If an unsupported version is passed.
        """
        # Map version names to RFDETR model classes. The mapping is built here
        # (not at import time) so a missing class only fails when requested.
        # NOTE(review): 'base2' maps to the same class as 'base'; presumably only
        # the weights file differs — confirm.
        model_cls = {
            'nano': rfdetr.RFDETRNano,
            'small': rfdetr.RFDETRSmall,
            'medium': rfdetr.RFDETRMedium,
            'base': rfdetr.RFDETRBase,
            'base2': rfdetr.RFDETRBase,
            'large': rfdetr.RFDETRLarge,
        }.get(version)

        if model_cls is None:
            raise ValueError(f"Unsupported version: {version}")

        self.model = model_cls(pretrain_weights=pretrain_weights)

    @staticmethod
    def _to_rgb(image):
        """
        Convert an OpenCV-style image to a contiguous 3-channel RGB array.

        Args:
            image (np.ndarray): Grayscale (H, W) / (H, W, 1), BGR (H, W, 3)
                or BGRA (H, W, 4) image.

        Returns:
            np.ndarray: The image in RGB channel order.

        Raises:
            TypeError: If ``image`` is None.
        """
        if image is None:
            raise TypeError("image must be a numpy array, got None")

        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            # Drop the alpha channel and reorder BGR -> RGB in one step.
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        # RF-DETR's predict() expects RGB; OpenCV images are BGR.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def infer(self, image, confidence=0.5, use_nms=False, nms_thresh=0.7):
        """
        Perform inference on a single image.

        Args:
            image (np.ndarray): Input image in OpenCV layout (BGR, BGRA or grayscale).
            confidence (float): Confidence threshold.
            use_nms (bool): Whether to apply class-agnostic Non-Maximum Suppression.
            nms_thresh (float): NMS IoU threshold.

        Returns:
            sv.Detections: Detection results including bounding boxes, class IDs, and confidences.

        Raises:
            TypeError: If ``image`` is None.
        """
        rgb_image = self._to_rgb(image)

        detections = self.model.predict(rgb_image, threshold=confidence)
        if use_nms:
            detections = detections.with_nms(threshold=nms_thresh, class_agnostic=True)

        # Rebuild a plain Detections object carrying only boxes, classes and scores.
        return sv.Detections(
            xyxy=np.array(detections.xyxy),
            class_id=np.array(detections.class_id),
            confidence=np.array(detections.confidence),
        )