# Uploaded by SFEREWQW ("Upload 395 files", commit 18e4106, verified).
import os
import cv2
import torch
from pdf_extract_kit.registry import MODEL_REGISTRY
from pdf_extract_kit.utils.visualization import visualize_bbox
from pdf_extract_kit.dataset.dataset import ImageDataset
@MODEL_REGISTRY.register('layout_detection_yolo')
class LayoutDetectionYOLO:
    """YOLO-based page layout detector registered as ``layout_detection_yolo``.

    Wraps a YOLO model and, when visualization is enabled, writes one
    annotated ``<name>_layout.png`` image per input into the result directory.
    """

    def __init__(self, config):
        """
        Initialize the LayoutDetectionYOLO class.

        Args:
            config (dict): Configuration dictionary containing model parameters.
                ``model_path`` is required; ``img_size``, ``conf_thres``,
                ``iou_thres``, ``visualize``, ``nc``, ``workers`` and
                ``device`` are optional and fall back to the defaults below.

        Raises:
            KeyError: If ``config`` has no ``model_path`` entry.
            ImportError: If neither ``doclayout_yolo`` nor ``ultralytics``
                can be imported.
        """
        # Mapping from class IDs to class names
        self.id_to_names = {
            0: 'title',
            1: 'plain text',
            2: 'abandon',
            3: 'figure',
            4: 'figure_caption',
            5: 'table',
            6: 'table_caption',
            7: 'table_footnote',
            8: 'isolate_formula',
            9: 'formula_caption',
        }

        # Load the YOLO model from the specified path.
        # ImportError covers a missing doclayout_yolo package, which previously
        # escaped the fallback entirely.
        # NOTE(review): AttributeError was the only exception caught before,
        # presumably for checkpoints doclayout_yolo cannot load - confirm.
        try:
            from doclayout_yolo import YOLOv10
            self.model = YOLOv10(config['model_path'])
        except (ImportError, AttributeError):
            from ultralytics import YOLO
            self.model = YOLO(config['model_path'])

        # Set model parameters
        self.img_size = config.get('img_size', 1280)
        self.conf_thres = config.get('conf_thres', 0.25)
        self.iou_thres = config.get('iou_thres', 0.45)
        self.visualize = config.get('visualize', False)
        self.nc = config.get('nc', 10)
        self.workers = config.get('workers', 8)
        self.device = config.get('device', 'cpu')

        # torchvision is only needed for the extra NMS pass applied before
        # visualization, so it is imported lazily.
        if self.iou_thres > 0:
            import torchvision
            self.nms_func = torchvision.ops.nms

    @staticmethod
    def _output_base_name(image, idx, image_ids):
        """Return the file-name stem used for the visualization of one image."""
        if image_ids:
            return image_ids[idx]
        try:
            # Remove directory and file extension
            return os.path.splitext(os.path.basename(image))[0]
        except TypeError:
            # `image` is not a path (e.g. an in-memory image); fall back to
            # its position in the input list instead of crashing.
            return str(idx)

    def predict(self, images, result_path, image_ids=None):
        """
        Predict page layouts in images.

        Args:
            images (list): List of images to be predicted.
            result_path (str): Directory in which visualizations are saved
                (only used when ``visualize`` is enabled).
            image_ids (list, optional): List of image IDs corresponding to the
                images; used as the stem of the saved file names.

        Returns:
            list: One raw model result per input image, in input order.
        """
        results = []
        for idx, image in enumerate(images):
            result = self.model.predict(
                image,
                imgsz=self.img_size,
                conf=self.conf_thres,
                iou=self.iou_thres,
                verbose=False,
                device=self.device,
            )[0]

            if self.visualize:
                os.makedirs(result_path, exist_ok=True)

                detections = result.boxes
                boxes = detections.xyxy
                classes = detections.cls
                scores = detections.conf

                if self.iou_thres > 0:
                    # Keep the tensors on their own device: the legacy
                    # `torch.Tensor(...)` constructor builds CPU float tensors
                    # and fails for data living on another device.
                    boxes = torch.as_tensor(boxes)
                    scores = torch.as_tensor(scores)
                    classes = torch.as_tensor(classes)
                    indices = self.nms_func(
                        boxes=boxes.float(),
                        scores=scores.float(),
                        iou_threshold=self.iou_thres,
                    )
                    boxes, scores, classes = boxes[indices], scores[indices], classes[indices]
                    # Restore the leading "detections" axis if indexing
                    # collapsed it (previously done with an un-imported `np`).
                    if boxes.dim() == 1:
                        boxes = boxes.unsqueeze(0)
                        scores = scores.unsqueeze(0)
                        classes = classes.unsqueeze(0)

                vis_result = visualize_bbox(image, boxes, classes, scores, self.id_to_names)

                # Determine the base name of the image
                base_name = self._output_base_name(image, idx, image_ids)
                result_name = f"{base_name}_layout.png"

                # Save the visualized result
                cv2.imwrite(os.path.join(result_path, result_name), vis_result)

            results.append(result)
        return results