OMG / inference /models /doctr /doctr_model.py
Fucius's picture
Upload 422 files
df6c67d verified
import os
import shutil
import tempfile
from time import perf_counter
from typing import Any, List, Union
from doctr import models as models
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from PIL import Image
from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.doctr import DoctrOCRInferenceResponse
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import MODEL_CACHE_DIR
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image
class DocTR(RoboflowCoreModel):
    """OCR model that combines a DocTR detection and recognition model.

    A ``DocTRDet`` and a ``DocTRRec`` instance are created first; their
    ``model.pt`` files under ``MODEL_CACHE_DIR`` are then copied to the file
    names expected when building the docTR ``ocr_predictor``.
    """

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR model.

        Args:
            *args: Variable length argument list (not read by this initializer).
            model_id (str): Identifier stored as ``self.endpoint``.
            **kwargs: Arbitrary keyword arguments; only ``api_key`` is read.
        """
        # NOTE(review): super().__init__ is not called here; the attributes
        # below are assigned directly — confirm this is deliberate.
        api_key = kwargs.get("api_key")
        self.api_key = api_key
        self.dataset_id = "doctr"
        self.version_id = "default"
        self.endpoint = model_id

        # Point docTR at the local model cache before the predictor is built.
        # NOTE(review): this names the doctr_rec directory only, while the
        # detection weights are copied under doctr_det/models/ below — confirm
        # docTR actually resolves the detection weights from there.
        os.environ["DOCTR_CACHE_DIR"] = os.path.join(MODEL_CACHE_DIR, "doctr_rec")

        self.det_model = DocTRDet(api_key=api_key)
        self.rec_model = DocTRRec(api_key=api_key)

        os.makedirs(f"{MODEL_CACHE_DIR}/doctr_rec/models/", exist_ok=True)
        os.makedirs(f"{MODEL_CACHE_DIR}/doctr_det/models/", exist_ok=True)

        # Copy each model.pt to the hashed file name used for pretrained weights.
        shutil.copyfile(
            f"{MODEL_CACHE_DIR}/doctr_det/db_resnet50/model.pt",
            f"{MODEL_CACHE_DIR}/doctr_det/models/db_resnet50-ac60cadc.pt",
        )
        shutil.copyfile(
            f"{MODEL_CACHE_DIR}/doctr_rec/crnn_vgg16_bn/model.pt",
            f"{MODEL_CACHE_DIR}/doctr_rec/models/crnn_vgg16_bn-9762b0b0.pt",
        )

        self.model = ocr_predictor(
            det_arch=self.det_model.version_id,
            reco_arch=self.rec_model.version_id,
            pretrained=True,
        )
        self.task_type = "ocr"

    def clear_cache(self) -> None:
        """Delegate cache clearing to the detection and recognition models."""
        self.det_model.clear_cache()
        self.rec_model.clear_cache()

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """
        DocTR pre-processes images as part of its inference pipeline.
        Thus, no preprocessing is required here.

        Args:
            image (Image.Image): The input image.

        Returns:
            Image.Image: The same image, unchanged.
        """
        # Return the input so the result matches the declared return type.
        return image

    def infer_from_request(
        self, request: DoctrOCRInferenceRequest
    ) -> DoctrOCRInferenceResponse:
        """Run OCR for a request and wrap the text in a response object.

        Args:
            request (DoctrOCRInferenceRequest): The inference request; its
                fields are passed to ``infer`` as keyword arguments.

        Returns:
            DoctrOCRInferenceResponse: The recognized text together with the
            elapsed time measured with ``perf_counter``.
        """
        t1 = perf_counter()
        result = self.infer(**request.dict())
        return DoctrOCRInferenceResponse(
            result=result,
            time=perf_counter() - t1,
        )

    def infer(self, image: Any, **kwargs) -> str:
        """
        Run inference on a provided image.

        Args:
            image (Any): The image to read text from, in any form accepted by
                ``load_image``.
            **kwargs: Additional keyword arguments; ignored.

        Returns:
            str: The words of the first page, joined by single spaces.
        """
        img = load_image(image)

        # Save into a private temporary directory instead of re-opening a
        # still-open NamedTemporaryFile by name, which is not possible on
        # every platform (notably Windows).
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, "image.jpg")
            # NOTE(review): if load_image yields BGR data, the channels are
            # interpreted as RGB here — confirm.
            Image.fromarray(img[0]).save(image_path)
            doc = DocumentFile.from_images([image_path])
            result = self.model(doc).export()

        blocks = result["pages"][0]["blocks"]
        # One string per text line, in block order.
        line_texts = [
            " ".join([word["value"] for word in line["words"]])
            for block in blocks
            for line in block["lines"]
        ]
        return " ".join(line_texts)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
class DocTRRec(RoboflowCoreModel):
    """Recognition-stage model wrapper instantiated by ``DocTR``."""

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Set up the recognition model by delegating to the base initializer.

        Args:
            *args: Positional arguments forwarded to the base initializer.
            model_id (str): Model identifier forwarded to the base initializer.
            **kwargs: Keyword arguments forwarded to the base initializer.
        """
        # The file list is queried before base initialization; its return
        # value is not used here.
        self.get_infer_bucket_file_list()
        init_kwargs = {"model_id": model_id, **kwargs}
        super().__init__(*args, **init_kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Name the files this model needs for inference.

        Returns:
            list: The required file names, i.e. ["model.pt"].
        """
        required_files = ["model.pt"]
        return required_files
class DocTRDet(RoboflowCoreModel):
    """Detection-stage model wrapper instantiated by ``DocTR``.

    The class adds no attributes of its own; everything beyond the required
    file list is delegated to the base class.
    """

    def __init__(self, *args, model_id: str = "doctr_det/db_resnet50", **kwargs):
        """Set up the detection model by delegating to the base initializer.

        Args:
            *args: Positional arguments forwarded to the base initializer.
            model_id (str): Model identifier forwarded to the base initializer.
            **kwargs: Keyword arguments forwarded to the base initializer.
        """
        # The file list is queried before base initialization; its return
        # value is not used here.
        self.get_infer_bucket_file_list()
        forwarded = {"model_id": model_id, **kwargs}
        super().__init__(*args, **forwarded)

    def get_infer_bucket_file_list(self) -> list:
        """Name the files this model needs for inference.

        Returns:
            list: The required file names, i.e. ["model.pt"].
        """
        file_names = ["model.pt"]
        return file_names