# Spaces: Runtime error
import os | |
import shutil | |
import tempfile | |
from time import perf_counter | |
from typing import Any, List, Union | |
from doctr import models as models | |
from doctr.io import DocumentFile | |
from doctr.models import ocr_predictor | |
from PIL import Image | |
from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest | |
from inference.core.entities.requests.inference import InferenceRequest | |
from inference.core.entities.responses.doctr import DoctrOCRInferenceResponse | |
from inference.core.entities.responses.inference import InferenceResponse | |
from inference.core.env import MODEL_CACHE_DIR | |
from inference.core.models.roboflow import RoboflowCoreModel | |
from inference.core.utils.image_utils import load_image | |
class DocTR(RoboflowCoreModel):
    """OCR model combining a DocTR detection model and a recognition model.

    Construction instantiates ``DocTRDet`` and ``DocTRRec``, copies each
    sub-model's ``model.pt`` to a hash-suffixed file name inside a ``models``
    directory, and then builds a docTR ``ocr_predictor`` with
    ``pretrained=True``.
    """

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initialize the sub-models and the OCR predictor.

        Args:
            *args: Accepted for interface compatibility; not used here.
            model_id: Identifier stored unchanged as ``self.endpoint``.
            **kwargs: Only ``api_key`` is read; it is passed to both
                sub-models.
        """
        # NOTE(review): the base-class initializer is not called here; the
        # attributes below are assigned by hand instead. Confirm this is
        # intentional.
        self.api_key = kwargs.get("api_key")
        self.dataset_id = "doctr"
        self.version_id = "default"
        self.endpoint = model_id

        os.environ["DOCTR_CACHE_DIR"] = os.path.join(MODEL_CACHE_DIR, "doctr_rec")

        # The copies below read each sub-model's ``model.pt`` from
        # MODEL_CACHE_DIR, so the sub-models must be constructed first
        # (presumably their initializers download the weights; verify).
        self.det_model = DocTRDet(api_key=self.api_key)
        self.rec_model = DocTRRec(api_key=self.api_key)

        # NOTE(review): DOCTR_CACHE_DIR points at the ``doctr_rec`` directory
        # only, while the detection weights are staged under ``doctr_det``.
        # Confirm that docTR actually picks the detection file up from there.
        self._stage_weights("doctr_det", "db_resnet50", "db_resnet50-ac60cadc.pt")
        self._stage_weights(
            "doctr_rec", "crnn_vgg16_bn", "crnn_vgg16_bn-9762b0b0.pt"
        )

        self.model = ocr_predictor(
            det_arch=self.det_model.version_id,
            reco_arch=self.rec_model.version_id,
            pretrained=True,
        )
        self.task_type = "ocr"

    @staticmethod
    def _stage_weights(family: str, arch: str, target_name: str) -> None:
        """Copy ``<family>/<arch>/model.pt`` to ``<family>/models/<target_name>``.

        Both paths are relative to ``MODEL_CACHE_DIR``. The ``models``
        directory is created when it does not exist yet.
        """
        models_dir = os.path.join(MODEL_CACHE_DIR, family, "models")
        os.makedirs(models_dir, exist_ok=True)
        shutil.copyfile(
            os.path.join(MODEL_CACHE_DIR, family, arch, "model.pt"),
            os.path.join(models_dir, target_name),
        )

    def clear_cache(self) -> None:
        """Clear the cache of both the detection and the recognition model."""
        self.det_model.clear_cache()
        self.rec_model.clear_cache()

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Return ``image`` unchanged.

        DocTR pre-processes images as part of its inference pipeline.
        Thus, no preprocessing is required here.
        """
        return image

    def infer_from_request(
        self, request: DoctrOCRInferenceRequest
    ) -> DoctrOCRInferenceResponse:
        """Run ``infer`` with the request's fields and time the call.

        Args:
            request: Inference request; ``request.dict()`` is expanded into
                keyword arguments for ``infer``.

        Returns:
            DoctrOCRInferenceResponse: Carries the recognized text as
            ``result`` and the elapsed seconds as ``time``.
        """
        started = perf_counter()
        text = self.infer(**request.dict())
        return DoctrOCRInferenceResponse(
            result=text,
            time=perf_counter() - started,
        )

    def infer(self, image: Any, **kwargs):
        """Run OCR on a single image and return the recognized text.

        Args:
            image: Image reference in whatever form ``load_image`` accepts.
            **kwargs: Ignored; accepted so ``infer_from_request`` can pass
                every request field.

        Returns:
            str: The words of each line joined by single spaces, and the
            lines of the first exported page joined by single spaces.
        """
        loaded = load_image(image)
        # Only the first element of ``load_image``'s result is used as pixel
        # data (presumably an (array, flag) pair; verify against load_image).
        pil_image = Image.fromarray(loaded[0])

        # A NamedTemporaryFile cannot be reopened by name on Windows while it
        # is still open, so the image is written into a temporary directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, "image.jpg")
            pil_image.save(image_path)
            doc = DocumentFile.from_images([image_path])
            exported = self.model(doc).export()

        blocks = exported["pages"][0]["blocks"]
        line_texts = (
            " ".join(word["value"] for word in line["words"])
            for block in blocks
            for line in block["lines"]
        )
        return " ".join(line_texts)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
class DocTRRec(RoboflowCoreModel):
    """Recognition model of the DocTR pipeline (default ``crnn_vgg16_bn``)."""

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initialize the DocTR recognition model.

        Args:
            *args: Positional arguments forwarded to the base initializer.
            model_id: Model identifier forwarded to the base initializer.
            **kwargs: Keyword arguments forwarded to the base initializer.
        """
        # The return value is discarded.
        # NOTE(review): this call looks redundant; confirm nothing relies on
        # it running before the base initializer.
        self.get_infer_bucket_file_list()
        super().__init__(*args, model_id=model_id, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
class DocTRDet(RoboflowCoreModel):
    """Detection model of the DocTR pipeline (default ``db_resnet50``).

    The only behaviour defined here is the default ``model_id`` and the list
    of files required for inference; everything else comes from the base
    class.
    """

    def get_infer_bucket_file_list(self) -> list:
        """Name the files that must be available for inference.

        Returns:
            list: The single entry ``"model.pt"``.
        """
        return ["model.pt"]

    def __init__(self, *args, model_id: str = "doctr_det/db_resnet50", **kwargs):
        """Set up the DocTR detection model.

        Args:
            *args: Positional arguments handed on to the base initializer.
            model_id: Model identifier handed on to the base initializer.
            **kwargs: Keyword arguments handed on to the base initializer.
        """
        # The returned list is not used at this point.
        self.get_infer_bucket_file_list()
        super().__init__(*args, model_id=model_id, **kwargs)