|
import os |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from yolov5.utils.augmentations import letterbox |
|
from yolov5.utils.general import non_max_suppression, scale_boxes as scale_coords |
|
from abc import ABC, abstractmethod |
|
|
|
class BaseModel(ABC):
    """Abstract interface that every detector wrapper in this module implements.

    A concrete subclass must override both ``pre_process`` and ``predict``;
    instantiating a subclass that leaves either one abstract raises ``TypeError``.
    """

    @abstractmethod
    def pre_process(self, filename: str):
        """Read the file at *filename* and turn it into model-ready input (e.g. a tensor)."""
        ...

    @abstractmethod
    def predict(self, input_data):
        """Run inference on the value produced by ``pre_process`` and return the predictions."""
        ...
|
|
|
class MegaDetectorModel(BaseModel):
    """
    MegaDetectorModel loads the MegaDetector checkpoint from a Hugging Face repository,
    preprocesses input images, runs inference, and returns detections (label/confidence).

    The repository ID is the only input required. The weight filename is derived from the
    repository's base name: for example, if the repository ID is "nkarthikeyan/MegaDetectorV5",
    the weight file downloaded is "MegaDetectorV5.pt".
    """

    def __init__(self, device='cpu', conf_thres=0.25, iou_thres=0.45, labels_path=None, img_size=640):
        """
        Store inference settings. The network itself is loaded by ``from_pretrained``.

        Args:
            device: torch device string, e.g. 'cpu' or 'cuda'.
            conf_thres: minimum confidence passed to non-max suppression.
            iou_thres: IoU threshold passed to non-max suppression.
            labels_path: optional text file with one class name per line; line i names
                class index i. Ignored when None or when the path does not exist.
            img_size: target size the image is letterboxed to before inference
                (default 640, the previously hard-coded value).
        """
        self.device = torch.device(device)
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.img_size = img_size
        self.labels = None
        if labels_path and os.path.exists(labels_path):
            with open(labels_path, "r", encoding="utf-8") as f:
                self.labels = [line.strip() for line in f]
        self.model = None

    @classmethod
    def from_pretrained(cls, repo_id: str, device: str = 'cpu', **kwargs):
        """
        Load the model checkpoint from the given Hugging Face repository and return an instance
        of MegaDetectorModel ready for inference.

        The repository's base name is used to derive the model weight filename. For example, if
        repo_id is "nkarthikeyan/MegaDetectorV5", then the weight file is expected to be
        "MegaDetectorV5.pt".

        Args:
            repo_id (str): The Hugging Face repository ID (e.g. "nkarthikeyan/MegaDetectorV5").
            device (str, optional): Device to run the model on ('cpu' or 'cuda'). Default is 'cpu'.
            **kwargs: forwarded to ``__init__`` (conf_thres, iou_thres, labels_path, img_size).

        Returns:
            MegaDetectorModel: An instance with the model loaded.
        """
        import inspect

        instance = cls(device=device, **kwargs)

        model_name = repo_id.split("/")[-1]
        weight_filename = f"{model_name}.pt"
        model_path = hf_hub_download(repo_id=repo_id, filename=weight_filename)

        # The checkpoint's 'model' entry is a pickled nn.Module (we call .float()/.fuse() on it),
        # so it must be fully unpickled. PyTorch >= 2.6 defaults to weights_only=True, which
        # rejects such checkpoints; request full unpickling explicitly when the argument exists
        # (it was added in torch 1.13).
        # SECURITY: full unpickling can execute arbitrary code - only load repositories you trust.
        # NOTE(review): unpickling presumably also needs the YOLOv5 model classes importable.
        load_kwargs = {"map_location": instance.device}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        checkpoint = torch.load(model_path, **load_kwargs)

        # nn.Module.to() returns the module and is a no-op when already on the target device.
        instance.model = checkpoint['model'].float().fuse().eval().to(instance.device)
        return instance

    def _require_model(self):
        """Raise a clear error if no network has been loaded yet."""
        if self.model is None:
            raise RuntimeError(
                "Model is not loaded; create the instance with MegaDetectorModel.from_pretrained()."
            )

    def _label_for(self, label_idx):
        """Map a class index to its name from the labels file, or return the index unchanged."""
        if self.labels and 0 <= label_idx < len(self.labels):
            return self.labels[label_idx]
        return label_idx

    def pre_process(self, filename: str):
        """
        Read an image from disk and convert it into a batched input tensor.

        Args:
            filename: path to an image readable by ``cv2.imread``.

        Returns:
            tuple ``(input_tensor, image_rgb)``. ``input_tensor`` is a float32 tensor of shape
            (1, 3, H, W) with values in [0, 1], on ``self.device``. ``image_rgb`` is the original
            image as an RGB numpy array, which ``predict`` needs in order to rescale boxes.

        Raises:
            RuntimeError: if no model has been loaded.
            ValueError: if the image cannot be read.
        """
        self._require_model()
        image_bgr = cv2.imread(filename)
        if image_bgr is None:
            raise ValueError(f"Could not load image from path: {filename}")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        # Padding must respect the network's largest stride.
        model_stride = int(self.model.stride.max())
        processed = letterbox(image_rgb, new_shape=self.img_size, stride=model_stride, auto=False)[0]
        processed = processed.transpose(2, 0, 1)  # HWC -> CHW
        processed = np.ascontiguousarray(processed, dtype=np.float32) / 255.0
        input_tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)
        return input_tensor, image_rgb

    def predict(self, input_data, *, return_boxes=False):
        """
        Run the detector on the output of ``pre_process`` and return the detections kept by NMS.

        Args:
            input_data: the ``(input_tensor, image_rgb)`` tuple returned by ``pre_process``.
            return_boxes: if True, each result additionally carries its ``[x1, y1, x2, y2]`` box,
                rescaled to the original image's pixel coordinates.

        Returns:
            list of ``(label, confidence)`` tuples, or ``(label, confidence, box)`` when
            ``return_boxes`` is True. ``label`` is the class name from the labels file when
            available, otherwise the integer class index.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        self._require_model()
        processed_tensor, original_rgb = input_data
        with torch.no_grad():
            prediction = self.model(processed_tensor)[0]
        detections = non_max_suppression(prediction, conf_thres=self.conf_thres, iou_thres=self.iou_thres)

        # One entry per image in the batch; pre_process always builds a batch of one.
        if not detections or detections[0] is None or len(detections[0]) == 0:
            return []
        det = detections[0]

        if return_boxes:
            # Undo the letterbox resize/padding so boxes refer to the original image's pixels.
            det[:, :4] = scale_coords(processed_tensor.shape[2:], det[:, :4], original_rgb.shape).round()

        results = []
        for *xyxy, conf, cls_idx in det.tolist():
            entry = (self._label_for(int(cls_idx)), float(conf))
            if return_boxes:
                entry = entry + (xyxy,)
            results.append(entry)
        return results
|
|