import os
import shutil
import tempfile
from time import perf_counter
from typing import Any, List, Union

from doctr import models as models
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from PIL import Image

from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.doctr import DoctrOCRInferenceResponse
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import MODEL_CACHE_DIR
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image


class DocTR(RoboflowCoreModel):
    """Two-stage DocTR OCR model: text detection followed by text recognition.

    The detection and recognition weights are fetched as two separate Roboflow
    core models (``DocTRDet`` and ``DocTRRec``) and then exposed to DocTR's
    ``ocr_predictor`` by copying them into DocTR's weight cache directory.
    """

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR model.

        Args:
            *args: Variable length argument list (not used).
            model_id: Identifier stored as ``self.endpoint``.
            **kwargs: Arbitrary keyword arguments; only ``api_key`` is read.
        """
        # NOTE: RoboflowCoreModel.__init__ is deliberately not called here; the
        # actual weight downloads are performed by the two sub-models below.
        self.api_key = kwargs.get("api_key")
        self.dataset_id = "doctr"
        self.version_id = "default"
        self.endpoint = model_id

        # DocTR looks up *all* pretrained weights (detection and recognition)
        # as ``$DOCTR_CACHE_DIR/models/<arch>-<hash>.pt``. Both files therefore
        # have to live in this single directory; previously the detection
        # weights were copied to ``doctr_det/models`` where DocTR never looks,
        # so DocTR fell back to downloading them itself.
        doctr_cache_dir = os.path.join(MODEL_CACHE_DIR, "doctr_rec")
        os.environ["DOCTR_CACHE_DIR"] = doctr_cache_dir

        self.det_model = DocTRDet(api_key=self.api_key)
        self.rec_model = DocTRRec(api_key=self.api_key)

        doctr_models_dir = os.path.join(doctr_cache_dir, "models")
        os.makedirs(doctr_models_dir, exist_ok=True)
        # The destination file names (including the hash suffix) must match
        # the names DocTR expects for these architectures. NOTE(review): DocTR
        # appears to validate the file against that hash suffix - keep these in
        # sync with the installed DocTR version.
        shutil.copyfile(
            os.path.join(MODEL_CACHE_DIR, "doctr_det", "db_resnet50", "model.pt"),
            os.path.join(doctr_models_dir, "db_resnet50-ac60cadc.pt"),
        )
        shutil.copyfile(
            os.path.join(MODEL_CACHE_DIR, "doctr_rec", "crnn_vgg16_bn", "model.pt"),
            os.path.join(doctr_models_dir, "crnn_vgg16_bn-9762b0b0.pt"),
        )

        self.model = ocr_predictor(
            det_arch=self.det_model.version_id,
            reco_arch=self.rec_model.version_id,
            pretrained=True,
        )
        self.task_type = "ocr"

    def clear_cache(self) -> None:
        """Delegate cache clearing to both the detection and recognition models."""
        self.det_model.clear_cache()
        self.rec_model.clear_cache()

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Return *image* unchanged.

        DocTR pre-processes images as part of its inference pipeline, so no
        preprocessing is required here.
        """
        return image

    def infer_from_request(
        self, request: DoctrOCRInferenceRequest
    ) -> DoctrOCRInferenceResponse:
        """Run OCR for a request and wrap the text in a timed response.

        Args:
            request: The OCR inference request; its fields are passed to
                ``infer`` as keyword arguments.

        Returns:
            DoctrOCRInferenceResponse: The recognized text and elapsed seconds.
        """
        t1 = perf_counter()
        result = self.infer(**request.dict())
        return DoctrOCRInferenceResponse(
            result=result,
            time=perf_counter() - t1,
        )

    def infer(self, image: Any, **kwargs) -> str:
        """Run OCR on a single image and return the recognized text.

        Args:
            image: Any image input accepted by ``load_image``.
            **kwargs: Ignored; accepted so that all request fields can be
                splatted in by ``infer_from_request``.

        Returns:
            str: The words of the first page, joined with single spaces in the
            block -> line -> word order reported by DocTR.
        """
        # load_image appears to return a tuple whose first element is the
        # decoded pixel array. NOTE(review): if that array is BGR rather than
        # RGB, channels are swapped below - confirm against load_image.
        img = load_image(image)
        pil_image = Image.fromarray(img[0])
        # JPEG cannot store alpha or other exotic modes; saving an RGBA image
        # would raise, so normalise to RGB first.
        if pil_image.mode not in ("RGB", "L"):
            pil_image = pil_image.convert("RGB")

        # DocumentFile.from_images expects file paths, so round-trip through a
        # temporary file. A temporary *directory* is used because a still-open
        # NamedTemporaryFile cannot be reopened by name on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, "image.jpg")
            pil_image.save(image_path)
            doc = DocumentFile.from_images([image_path])
            exported = self.model(doc).export()

        page_blocks = exported["pages"][0]["blocks"]
        line_texts = [
            " ".join(word["value"] for word in line["words"])
            for block in page_blocks
            for line in block["lines"]
        ]
        return " ".join(line_texts)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]


class DocTRRec(RoboflowCoreModel):
    """Roboflow core model holding the DocTR text-recognition weights."""

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR recognition model.

        Args:
            *args: Variable length argument list, forwarded to the base class.
            model_id: Roboflow core model identifier of the recognition weights.
            **kwargs: Arbitrary keyword arguments, forwarded to the base class.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]


class DocTRDet(RoboflowCoreModel):
    """Roboflow core model holding the DocTR text-detection weights."""

    def __init__(self, *args, model_id: str = "doctr_det/db_resnet50", **kwargs):
        """Initializes the DocTR detection model.

        Args:
            *args: Variable length argument list, forwarded to the base class.
            model_id: Roboflow core model identifier of the detection weights.
            **kwargs: Arbitrary keyword arguments, forwarded to the base class.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]