# Spaces: Paused
import json
import os
import tempfile
from typing import Any, Optional, Dict, List

from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from langchain.chains import LLMChain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from paddleocr import PaddleOCR
from passporteye import read_mrz
from pydantic import BaseModel, Field
# Credentials and model selection for the Hugging Face Inference API.
# `HF_token` is None when the "apiToken" environment variable is unset;
# `hf_token` is an alias bound to the very same value.
HF_token = hf_token = os.getenv("apiToken")
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Generation settings handed to CustomInferenceClient, which forwards them to
# `InferenceClient.text_generation`.
kwargs = dict(
    max_new_tokens=500,
    temperature=0.1,
    top_p=0.95,
    repetition_penalty=1.0,
    do_sample=True,
)
class KwArgsModel(BaseModel):
    """Pydantic mixin holding a free-form mapping of keyword arguments.

    CustomInferenceClient inherits from it and forwards ``kwargs`` to the
    text-generation call.
    """

    # Each instance gets its own fresh empty dict unless one is supplied.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain LLM wrapper around ``huggingface_hub.InferenceClient``.

    Prompts are sent to ``text_generation`` with the generation settings held
    in ``self.kwargs`` (inherited from KwArgsModel). The streamed tokens are
    concatenated into a single completion string.
    """

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: Optional[str], kwargs: Optional[Dict[str, Any]] = None):
        """Create the wrapper and its underlying InferenceClient.

        Args:
            model_name: Hugging Face Hub model id to query.
            hf_token: Inference API token. It is only handed to the
                InferenceClient and is deliberately not stored as a field on
                this model (it is not a declared field, and keeping secrets off
                the model avoids leaking them via repr/serialization).
            kwargs: Generation parameters for ``text_generation``. ``None``
                means "no extra parameters".
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            # `kwargs` is declared as a Dict field, so None would fail validation.
            kwargs=kwargs if kwargs is not None else {},
            inference_client=inference_client,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **call_kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Fully rendered prompt text.
            stop: Stop sequences. They are not supported and must be None.
            run_manager: LangChain callback manager. When given, every
                streamed token is reported to it.
            **call_kwargs: Per-call generation parameters. They override
                ``self.kwargs``.

        Raises:
            ValueError: If ``stop`` is provided.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        params = {**self.kwargs, **call_kwargs}
        chunks = []
        for token in self.inference_client.text_generation(
            prompt, **params, stream=True, return_full_text=False
        ):
            chunks.append(token)
            if run_manager is not None:
                run_manager.on_llm_new_token(token)
        return "".join(chunks)

    # LangChain's base class reads these two as attributes (e.g.
    # `dict(self._identifying_params)`), so they must be properties.
    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {"model_name": self.model_name}
app = FastAPI(title="Passport Recognition API")

# CORS: allow browser clients from any origin. Credentials are disabled
# because FastAPI's CORS documentation states that `allow_origins=["*"]`
# cannot be combined with `allow_credentials=True`, and nothing in this file
# reads cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Shared PaddleOCR engine, created once at import time and reused by every
# request. English recognition. The angle classifier is loaded so that the
# `cls=True` option used at call time has a model to run.
ocr = PaddleOCR(lang='en', use_angle_cls=True)
# Prompt for the LLM field-extraction step. `{ocr_result}` is the only
# template variable. The doubled braces around the field list render as
# literal `{` / `}` under PromptTemplate's default f-string formatting, so the
# model is shown a JSON object skeleton to fill in.
template = """Below is a poorly read OCR result of a passport.
OCR Result:
{ocr_result}
Fill in the categories below using the OCR result. You can correct spellings and make other adjustments. Dates should be in 01-JAN-2000 format.
{{
"countryName": "",
"dateOfBirth": "",
"dateOfExpiry": "",
"dateOfIssue": "",
"documentNumber": "",
"givenNames": "",
"name": "",
"surname": "",
"mrz": ""
}}
json output:
"""
# LangChain prompt whose single input variable, `ocr_result`, receives the
# OCR lines extracted from the uploaded image.
prompt = PromptTemplate(input_variables=["ocr_result"], template=template)
class MRZData(BaseModel):
    """Machine-readable-zone fields, populated from passporteye's ``MRZ.to_dict()``."""

    # Decoded document data.
    date_of_birth: str
    expiration_date: str
    type: str
    number: str
    names: str
    country: str

    # Check digits as read from the MRZ.
    check_number: str
    check_date_of_birth: str
    check_expiration_date: str
    check_composite: str
    check_personal_number: str

    # Presumably passporteye's check-digit verification results — confirm
    # against the passporteye docs.
    valid_number: bool
    valid_date_of_birth: bool
    valid_expiration_date: bool
    valid_composite: bool
    valid_personal_number: bool

    # Detection method reported by passporteye.
    method: str
class OCRData(BaseModel):
    """Passport fields extracted by the LLM from the raw OCR text.

    Field names mirror the JSON keys requested by the extraction prompt.
    Every field defaults to an empty string, which is the same placeholder the
    prompt shows the model. A completion that omits a key therefore still
    validates instead of failing the whole request.
    """

    countryName: str = ""
    dateOfBirth: str = ""  # the prompt asks for 01-JAN-2000 format; not validated here
    dateOfExpiry: str = ""
    dateOfIssue: str = ""
    documentNumber: str = ""
    givenNames: str = ""
    name: str = ""
    surname: str = ""
    mrz: str = ""
class ResponseData(BaseModel):
    """Top-level API response combining the MRZ and OCR/LLM extraction results."""

    # Document kind; create_response_data always sets "Passport".
    documentName: str
    # create_response_data always sets 0 (success).
    errorCode: int
    mrz: MRZData
    ocr: OCRData
    # create_response_data always sets "ok".
    status: str
def create_response_data(mrz, ocr_data):
    """Wrap MRZ and OCR field mappings in a successful ResponseData.

    Args:
        mrz: Mapping of MRZ fields, unpacked into MRZData.
        ocr_data: Mapping of LLM-extracted fields, unpacked into OCRData.

    Returns:
        A ResponseData with documentName "Passport", errorCode 0 and status "ok".
    """
    mrz_model = MRZData(**mrz)
    ocr_model = OCRData(**ocr_data)
    return ResponseData(
        documentName="Passport",
        errorCode=0,
        mrz=mrz_model,
        ocr=ocr_model,
        status="ok",
    )
def _run_ocr(image_bytes: bytes) -> List[Dict[str, Any]]:
    """Run the shared PaddleOCR engine on raw image bytes.

    The bytes are written to a private temporary file, because PaddleOCR is
    called with a path here. That file is always deleted afterwards. This
    stops workers from overwriting each other's image and keeps uploaded
    passport images from lingering on disk.

    Returns:
        One dict per detected text line with keys ``coordinates``, ``text``
        and ``confidence``.
    """
    fd, img_path = tempfile.mkstemp(suffix=".jpg")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(image_bytes)
        result = ocr.ocr(img_path, cls=True)
    finally:
        os.remove(img_path)

    # A page entry can be None when nothing is detected (paddleocr 2.x
    # behaviour — confirm for the pinned version), so treat it as empty.
    return [
        {"coordinates": coordinates, "text": text, "confidence": confidence}
        for page in (result or [])
        for coordinates, (text, confidence) in (page or [])
    ]


def _parse_llm_json(completion: str) -> Dict[str, Any]:
    """Decode the JSON object embedded in an LLM completion.

    Everything before the first ``{`` and after the last ``}`` is discarded.
    This drops a trailing ``</s>`` end-of-sequence marker as well as any prose
    the model adds around the object. ``null`` values become ``""`` because
    every OCRData field is a string.

    Raises:
        ValueError: If no JSON object can be found or decoded
            (``json.JSONDecodeError`` is a ValueError).
    """
    start = completion.find("{")
    end = completion.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"LLM response contains no JSON object: {completion[:200]!r}")
    data = json.loads(completion[start:end + 1])
    return {key: "" if value is None else value for key, value in data.items()}


# NOTE(review): this handler had no route decorator, so FastAPI never exposed
# it. The path below is assumed from the function name — confirm it matches
# what clients call.
@app.post("/recognize_passport", response_model=ResponseData)
async def recognize_passport(image: UploadFile = File(...)):
    """Passport information extraction from a provided image file.

    Reads the MRZ with passporteye, OCRs the whole image with PaddleOCR, and
    asks the LLM to turn the OCR lines into structured fields.

    Raises:
        HTTPException:
            400 if the upload is empty.
            422 if no MRZ is found in the image.
            500 for any other failure (OCR, the LLM call, or an unparseable
            LLM response).
    """
    # NOTE(review): OCR and the LLM call are blocking and run on the event
    # loop. Moving them to a thread pool would let requests overlap, but it
    # would also run the shared PaddleOCR engine concurrently. Confirm that
    # is thread-safe first.
    try:
        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Uploaded file is empty.",
            )

        mrz = read_mrz(image_bytes)
        if mrz is None:
            # read_mrz returns None when it cannot locate an MRZ. Without this
            # check, `mrz.to_dict()` would surface as an opaque 500.
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No machine-readable zone (MRZ) found in the image.",
            )

        json_result = _run_ocr(image_bytes)

        llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        response_str = llm_chain.run(ocr_result=json_result)
        ocr_data = _parse_llm_json(response_str)

        return create_response_data(mrz.to_dict(), ocr_data)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal server error: {str(e)}",
        ) from e