|
|
from ultralytics import YOLO |
|
|
from glob import glob |
|
|
import matplotlib.pyplot as plt |
|
|
import cv2 |
|
|
import os |
|
|
from PIL import Image |
|
|
from ultralytics.engine.results import Results |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class detection:
    """Detect word regions with a YOLO model and return them as image crops.

    Crops are grouped into text lines: the result is a list of lines ordered
    top to bottom, and each line is a list of crops ordered horizontally
    (right-to-left by default).
    """

    def __init__(self, model_path='detection.pt'):
        """Load the YOLO weights.

        Args:
            model_path: Path to the weights file. A relative path is resolved
                against the directory containing this module; an absolute path
                is used unchanged (``os.path.join`` semantics).
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, model_path)
        self.model = YOLO(model_path)

    def get_distance(self, res):
        """Order the detections of ``res[0]`` top-to-bottom, then left-to-right.

        The sort key is the box centre: centre-y first, ties broken by
        centre-x (``np.lexsort`` treats its *last* key as the primary one).

        Args:
            res: Prediction results as returned by ``self.model(image)``.

        Returns:
            tuple: ``(centre_ys, order)``, i.e. the centre-y values in sorted
            order and the index array that produces that order.
        """
        # .cpu() first: when inference ran on a GPU the tensor lives on the
        # device, and calling .numpy() on it directly raises.
        centres = res[0].boxes.xywh.cpu().numpy()
        order = np.lexsort((centres[:, 0], centres[:, 1]))
        return centres[order, 1], order

    @staticmethod
    def _order_line(line, right_to_left=True):
        """Return the boxes of one text line sorted by their left edge (xmin)."""
        return sorted(line, key=lambda box: box[0], reverse=right_to_left)

    def handle_the_boxes(self, res, img, y_threshold=30, right_to_left=True):
        """Group detections into text lines and crop every word from ``img``.

        Boxes are visited in the order produced by ``get_distance``. A box
        starts a new line when its top edge (ymin) differs from the top edge
        of the previous box in the current line by more than ``y_threshold``.

        Args:
            res: Prediction results for ``img``.
            img: Image the boxes refer to, indexed as ``img[y, x]``.
            y_threshold: Largest vertical gap, in pixels between top edges,
                for two consecutive boxes to stay on the same line.
            right_to_left: Order the words of each line by descending xmin
                (True, the default) or ascending xmin (False).

        Returns:
            list[list[numpy.ndarray]]: One list of crops per line, with lines
            ordered top to bottom. Empty when nothing was detected.
        """
        _, order = self.get_distance(res)
        boxes = res[0].boxes.xyxy.cpu().numpy()[order]
        if len(boxes) == 0:
            # Nothing detected; there is no first box to seed a line with.
            return []

        lines = []
        current_line = [boxes[0]]
        for box in boxes[1:]:
            if abs(box[1] - current_line[-1][1]) > y_threshold:
                lines.append(self._order_line(current_line, right_to_left))
                current_line = [box]
            else:
                current_line.append(box)
        # Close the final line with exactly the same ordering as the others.
        lines.append(self._order_line(current_line, right_to_left))

        return [[self.words_pixels(img, box) for box in line] for line in lines]

    def words_pixels(self, img, xyxy):
        """Crop the region ``xyxy = (xmin, ymin, xmax, ymax)`` out of ``img``.

        Coordinates are truncated to ``int``. The max edges are treated as
        inclusive (hence ``+ 1``). The result is a numpy slice (a view), not
        a copy.
        """
        xmin, ymin, xmax, ymax = (int(v) for v in xyxy.tolist())
        return img[ymin:ymax + 1, xmin:xmax + 1]

    def full_pipeline(self, image, show=False):
        """Run detection on ``image`` and return the word crops grouped by line.

        Args:
            image: A numpy array, which is used directly as the crop source,
                or any other source the YOLO model accepts (e.g. a file path).
                For non-array sources the crops come from the image the model
                itself decoded (``Results.orig_img``), so the input is read
                only once.
            show: If True, display the annotated detections.

        Returns:
            list[list[numpy.ndarray]]: Crops per line; see ``handle_the_boxes``.
        """
        res = self.model(image)
        if show:
            res[0].show()
        # Reusing the model's decoded image avoids a second disk read. It also
        # covers every input type the model supports, so there is never an
        # unbound or None crop source.
        img = image if isinstance(image, np.ndarray) else res[0].orig_img
        return self.handle_the_boxes(res, img)
|
|
|
|
|
|
|
|
|
|
|
|