# File: stamp_processing/detector.py
# Last change: "Update stamp_processing/detector.py" by nguyenp99 (commit d749e18, verified; 3.15 kB).
import os
from functools import partial
from typing import List, Union
import numpy as np
import numpy.typing as npt
import torch
import plasma.huggingface as hf
from .module.yolov5.yolo_utils.datasets import letterbox
from .module.yolov5.yolo_utils.general import non_max_suppression, scale_coords
from .preprocess import create_batch, process_image
from .utils import (
DETECTOR_WEIGHT_ID,
check_image_shape,
load_yolo_model,
)
class StampDetector:
    """Stamp detector backed by a YOLO model loaded through ``load_yolo_model``.

    Instances are callable: pass a list of images (or an array whose first axis
    indexes images) and receive, for each image, a list of
    ``[xmin, ymin, xmax, ymax]`` integer boxes.
    """

    def __init__(
        self,
        model_path: Union[str, None] = None,
        device: str = "cpu",
        conf_thres: float = 0.5,
        iou_thres: float = 0.3,
    ) -> None:
        """Create an object for stamp detection.

        Args:
            model_path: checkpoint reference passed to ``hf.download_file``.
                When ``None``, ``DETECTOR_WEIGHT_ID`` is used instead.
            device: torch device string forwarded to ``load_yolo_model`` and
                ``process_image``.
            conf_thres: confidence threshold passed to ``non_max_suppression``.
            iou_thres: IoU threshold passed to ``non_max_suppression``.
        """
        if model_path is None:
            # Previously None was handed straight to hf.download_file.
            # NOTE(review): assumes DETECTOR_WEIGHT_ID is a reference that
            # hf.download_file accepts — confirm against .utils.
            model_path = DETECTOR_WEIGHT_ID
        checkpoint = hf.download_file(model_path)
        self.device = device
        self.model, self.stride = load_yolo_model(checkpoint, device=device)
        # Target size for letterbox resizing.
        self.img_size = 640
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.process_func_ = partial(process_image, device=device)

    def __call__(self, image_list: Union[List[npt.NDArray], npt.NDArray]) -> List[List[List[int]]]:
        """Detect stamps in every image of ``image_list``.

        Args:
            image_list (Union[List[npt.NDArray], npt.NDArray]): input images.

        Returns:
            List[List[List[int]]]: one entry per input image, in input order.
            Each entry is a (possibly empty) list of ``[xmin, ymin, xmax, ymax]``
            integer boxes, rescaled to that image's original size.

        Raises:
            TypeError: if ``image_list`` is neither a list nor a numpy array.
        """
        if not isinstance(image_list, (np.ndarray, list)):
            raise TypeError("Invalid Type: Input must be of type list or np.ndarray")
        if len(image_list) == 0:
            return []
        # NOTE(review): only the first image is shape-checked; the rest are
        # assumed to be valid as well.
        check_image_shape(image_list[0])
        return self.__detect(image_list)  # type: ignore

    def __detect(self, image_list):  # type: ignore
        """Run detection on a non-empty, already validated image sequence.

        Use the ``__call__`` method instead of calling this directly.
        """
        # Group images by shape so each group can be stacked into one tensor.
        batches, indices = create_batch(image_list, {x.shape for x in image_list})
        predictions = []
        for origin_images in batches:
            predictions.extend(self._detect_batch(origin_images))
        # NOTE(review): assumes `indices` lists, in batch order, the original
        # position of each image — confirm in .preprocess.create_batch.
        ordered = sorted(zip(predictions, indices), key=lambda pair: pair[1])
        return [boxes for boxes, _ in ordered]

    def _detect_batch(self, origin_images):  # type: ignore
        """Detect boxes for one batch of images; returns one box list per image."""
        # Letterbox with the configured size and the model's own stride
        # (these were previously hard-coded to 640 and 32).
        images = [letterbox(x, self.img_size, stride=self.stride)[0] for x in origin_images]
        images = [self.process_func_(x) for x in images]
        tensor = torch.stack(images).half()
        with torch.no_grad():
            pred = self.model(tensor)[0]
        # Honour the thresholds given to the constructor. They were previously
        # ignored in favour of hard-coded 0.3 / 0.30.
        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, classes=0, agnostic=1)  # type: ignore
        batch_boxes = []
        for image, origin, det in zip(images, origin_images, pred):
            if len(det):
                # Map boxes from the letterboxed tensor back to this image's
                # own original shape (previously always the first image's).
                det[:, :4] = scale_coords(image.shape[1:], det[:, :4], origin.shape)  # type: ignore
                batch_boxes.append(det[:, :4].round().cpu().numpy().astype("int").tolist())
            else:
                batch_boxes.append([])
        return batch_boxes