File size: 2,187 Bytes
6ad7442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5653bd
 
 
 
 
 
 
 
4002914
6ad7442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0027d7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import plasma.functional as f
import plasma.huggingface as hf
import plasma.meta as meta

from .config import Config
from .model import YOLORunner
from .preprocesses import Preprocessor
from .post_processes import TableRestore

from ultralytics import YOLO
from .stamp_processing.detector import StampDetector
from .stamp_processing.callback import remove_potiential_table_fp, remove_box_by_idc


class Engine(f.Pipe):
    """Table-detection pipeline built on a YOLO model, with optional stamp filtering.

    Construction wires up a preprocessor, a post-processor, a YOLO-based table
    detector and a stamp detector, all handed to ``f.Pipe.__init__`` as keyword
    arguments (presumably stored as attributes of the same name — ``__init__``
    relies on ``self.stamp_detector`` existing afterwards). When
    ``cfg.USE_STAMP_DETECTION`` is set, a YOLO callback removes predicted table
    boxes selected by ``remove_potiential_table_fp`` using the stamp detector.
    """

    def __init__(self, cfg: Config = None, table_checkpoint=None, line_checkpoint=None, verbose=True):
        """Build the engine.

        Args:
            cfg: Engine configuration. A default ``Config()`` is created when None.
            table_checkpoint: Path to the table-detection YOLO weights. When None,
                ``cfg.TABLE_DETECTION_CHECKPOINT`` is fetched with ``hf.download_file``.
            line_checkpoint: Currently unused; accepted for interface compatibility.
            verbose: Currently unused. NOTE(review): it is not forwarded to
                ``_build_table_detector`` (which defaults to ``verbose=False``);
                confirm whether that is intentional before wiring it through.
        """
        # Resolve the default config *before* any attribute access. Previously
        # only the value passed to super().__init__ was defaulted, so every
        # later ``cfg.X`` read hit ``None`` and ``Engine()`` raised AttributeError.
        if cfg is None:
            cfg = Config()
        super().__init__(
            cfg=cfg,
            standard_width=cfg.STANDARD_WIDTH,
            preprocessor=Preprocessor(cfg.STANDARD_WIDTH, cfg.STANDARD_INTERPOLATION),
            post_processor=TableRestore(),
            # Built before super().__init__ has run: the stamp callback it may
            # register reads self.cfg / self.stamp_detector only when YOLO fires
            # the "on_predict_postprocess_end" event, not at registration time.
            table_detector=self._build_table_detector(cfg, table_checkpoint),
            stamp_detector=StampDetector(model_path=cfg.STAMP_DETECTION_CHECKPOINT, device=cfg.DEVICE),
        )
        self.stamp_detector.model.to(cfg.DEVICE)

    def run(self, image):
        """Detect tables in ``image`` and return the table detector's output as-is.

        NOTE(review): resizing the image to ``standard_width`` via
        ``self.preprocessor`` and rescaling the results via
        ``self.post_processor(tables, standard_width / image.shape[1])`` are
        currently disabled; the raw image goes straight to ``table_detector``.
        """
        return self.table_detector(image)

    def _build_table_detector(self, cfg: Config, checkpoint=None, verbose=False):
        """Load the YOLO table model and wrap it in a ``YOLORunner``.

        Args:
            cfg: Configuration supplying the checkpoint, device, image size,
                confidence threshold and height-expand ratio.
            checkpoint: Local weights path; downloaded from
                ``cfg.TABLE_DETECTION_CHECKPOINT`` when None.
            verbose: Passed through to ``YOLORunner``.

        Returns:
            A ``YOLORunner`` around the loaded model.
        """
        if checkpoint is None:
            checkpoint = hf.download_file(cfg.TABLE_DETECTION_CHECKPOINT)
        model = YOLO(checkpoint).to(cfg.DEVICE)
        if cfg.USE_STAMP_DETECTION:
            # Post-filter every prediction through the stamp detector.
            model.add_callback("on_predict_postprocess_end", self._stamp_detection_callback)

        return YOLORunner(model, cfg.YOLO_IMAGE_SIZE, cfg.CONF_THRESHOLD, cfg.HEIGHT_EXPAND_RATIO, cfg.DEVICE, verbose)

    def _stamp_detection_callback(self, predictor):
        """YOLO callback: drop table boxes selected by ``remove_potiential_table_fp``.

        Replaces ``predictor.results[0].boxes`` in place with the filtered boxes.

        Raises:
            ValueError: if the predictor produced more than one result; only a
                batch size of 1 is supported.
        """
        # Explicit check instead of ``assert`` so it is not stripped under -O,
        # which would silently leave all but the first image unfiltered.
        if len(predictor.results) != 1:
            raise ValueError('Only support batch size 1')
        preds = predictor.results[0]
        remove_idc = remove_potiential_table_fp(
            self.stamp_detector,
            preds.orig_img,
            preds.boxes.xyxy,
            self.cfg.STAMP_REMOVING_IOU_THRESHOLD,
        )
        preds.boxes = remove_box_by_idc(preds.boxes, remove_idc)