# Source captured from a Hugging Face Space file view (Space status: Paused; file size: 5,835 bytes).
# The viewer's blame column (commits 47315cd, 528f442, c5446e2, eb28cbb) and its 1-203 line-number gutter were removed.
import json
from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from paddleocr import PaddleOCR
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from passporteye import read_mrz
from pydantic.v1 import BaseModel as v1BaseModel
from pydantic.v1 import Field
from pydantic import BaseModel
from typing import Any, Optional, Dict, List
from huggingface_hub import InferenceClient
from langchain.llms.base import LLM
import os
# Hugging Face API token taken from the "apiToken" environment variable;
# None when the variable is not set.
HF_token = os.environ.get("apiToken")
# Model queried through the Hugging Face Inference API.
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Lower-case alias of the token, used when building the LLM client.
hf_token = HF_token
# Default generation options; CustomInferenceClient forwards them to
# InferenceClient.text_generation.
kwargs = dict(
    max_new_tokens=500,
    temperature=0.1,
    top_p=0.95,
    repetition_penalty=1.0,
    do_sample=True,
)
class KwArgsModel(v1BaseModel):
    """Pydantic v1 mixin contributing a ``kwargs`` dict field.

    CustomInferenceClient inherits from it and unpacks ``self.kwargs`` as
    generation options into ``InferenceClient.text_generation``.
    """
    # Free-form generation options; a fresh empty dict per instance by default.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain LLM backed by ``huggingface_hub.InferenceClient``.

    Prompts are sent through ``InferenceClient.text_generation`` using the
    generation options stored in ``self.kwargs``. The streamed tokens are
    joined into a single completion string.
    """

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Create the InferenceClient for ``model_name``.

        Args:
            model_name: Hugging Face model id to query.
            hf_token: API token handed to InferenceClient. The model declares
                no field for it, so it is not passed to the base initializer.
            kwargs: Default generation options. ``None`` means no options.
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            # The ``kwargs`` field is a non-Optional Dict, so pydantic v1
            # rejects an explicit None; substitute an empty dict instead.
            kwargs=kwargs if kwargs is not None else {},
            inference_client=inference_client,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **call_kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Fully formatted prompt text.
            stop: Unsupported; must be None.
            run_manager: LangChain callback manager. It is accepted for
                interface compatibility and not used.
            **call_kwargs: Per-call generation options that override
                ``self.kwargs``.

        Returns:
            The generated text, without the prompt.

        Raises:
            ValueError: If ``stop`` is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        params = {**self.kwargs, **call_kwargs}
        # stream/return_full_text are fixed below. Drop any caller-supplied
        # copies so they cannot trigger a duplicate-keyword TypeError.
        params.pop("stream", None)
        params.pop("return_full_text", None)
        response_gen = self.inference_client.text_generation(
            prompt, **params, stream=True, return_full_text=False
        )
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        """Parameters identifying this LLM instance (the model name only)."""
        return {"model_name": self.model_name}
app = FastAPI(title="Passport Recognition API")

# Open CORS policy: every origin, HTTP method and request header is allowed,
# and credentialed requests are permitted.
# NOTE(review): browsers refuse a literal "*" Access-Control-Allow-Origin on
# credentialed requests. Check how the installed Starlette CORSMiddleware
# combines allow_origins=["*"] with allow_credentials=True, and confirm
# that this is the intended policy.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Single PaddleOCR engine built once at import time and shared by every
# request. It uses English recognition with the angle classifier enabled
# (recognize_passport calls it with cls=True).
ocr = PaddleOCR(lang='en', use_angle_cls=True)
# Prompt sent to the LLM. ``{ocr_result}`` is its only variable. It is filled
# with the list of OCR line dicts (coordinates, text, confidence) built in
# recognize_passport. The keys listed below must stay in sync with OCRData,
# which validates the parsed reply.
# NOTE(review): the prompt text contains typos ("catergories", "adujustments")
# and does not ask for the fields wrapped in a single {...} JSON object. It is
# left unchanged on purpose, because editing the prompt changes model output.
template = """below is poorly read ocr result of a passport.
OCR Result:
{ocr_result}
Fill the below catergories using the OCR Results. you can correct spellings and make other adujustments. Dates should be in 01-JAN-2000 format.
"countryName": "",
"dateOfBirth": "",
"dateOfExpiry": "",
"dateOfIssue": "",
"documentNumber": "",
"givenNames": "",
"name": "",
"surname": "",
"mrz": ""
json output:
"""
# PromptTemplate substitutes {ocr_result} using LangChain's default f-string
# format. Any literal braces added to the template must be doubled ({{ }}).
prompt = PromptTemplate(template=template, input_variables=["ocr_result"])
class MRZData(BaseModel):
    """MRZ fields returned in the ``mrz`` part of the API response.

    Built in ``create_response_data`` from ``mrz.to_dict()``, where ``mrz`` is
    the PassportEye ``read_mrz`` result. The field names presumably mirror the
    keys of that dict; verify this against the installed passporteye version.
    Every field is required, so a missing key makes validation fail.
    """
    # Identity fields ("type" shadows the builtin but is part of the schema).
    date_of_birth: str
    expiration_date: str
    type: str
    number: str
    names: str
    country: str
    # Presumably the check digits read from the MRZ; confirm with passporteye.
    check_number: str
    check_date_of_birth: str
    check_expiration_date: str
    check_composite: str
    check_personal_number: str
    # Presumably whether each check digit validated; confirm with passporteye.
    valid_number: bool
    valid_date_of_birth: bool
    valid_expiration_date: bool
    valid_composite: bool
    valid_personal_number: bool
    # Presumably the PassportEye extraction method used. TODO: confirm.
    method: str
class OCRData(BaseModel):
    """Passport fields the LLM extracted from the PaddleOCR text.

    Built from the JSON object parsed out of the LLM reply in
    ``recognize_passport``. The field names match the keys listed in the
    prompt ``template``. Every field is a required ``str``, so a missing key
    or a null value in the reply fails validation.
    """
    countryName: str
    # The three dates are requested in 01-JAN-2000 format by the prompt.
    dateOfBirth: str
    dateOfExpiry: str
    dateOfIssue: str
    documentNumber: str
    givenNames: str
    name: str
    surname: str
    # Presumably the MRZ text as reconstructed by the LLM; it is distinct from
    # the PassportEye result carried in ResponseData.mrz.
    mrz: str
class ResponseData(BaseModel):
    """Response body of ``POST /recognize_passport``.

    ``create_response_data`` sets ``errorCode``/``status`` as follows:
    0 "ok" when both parts are present, 1 when neither is, 2 when the MRZ is
    missing, and 3 when the OCR data is missing.
    """
    documentName: str  # always "Passport" in create_response_data
    errorCode: int
    mrz: Optional[MRZData]  # None when PassportEye found no MRZ
    ocr: Optional[OCRData]  # None when no OCR/LLM data is available
    status: str  # human-readable counterpart of errorCode
def create_response_data(mrz, ocr_data):
    """Assemble the API response from the MRZ dict and the LLM-extracted fields.

    Args:
        mrz: Dict from PassportEye's ``to_dict()``, or a falsy value when no
            MRZ was found.
        ocr_data: Dict of passport fields parsed from the LLM reply, or a
            falsy value when it is unavailable.

    Returns:
        A ResponseData whose ``errorCode``/``status`` report which parts are
        present: 0 both, 1 neither, 2 OCR only, 3 MRZ only.
    """
    have_mrz = bool(mrz)
    have_ocr = bool(ocr_data)

    # Only build models for the parts that exist; the MRZ model comes first.
    mrz_model = MRZData(**mrz) if have_mrz else None
    ocr_model = OCRData(**ocr_data) if have_ocr else None

    if have_mrz and have_ocr:
        error_code, message = 0, "ok"
    elif have_mrz:
        error_code, message = 3, "OCR result not available"
    elif have_ocr:
        error_code, message = 2, "PassportEYE did not find an MRZ"
    else:
        error_code, message = 1, "No MRZ or OCR data available"

    return ResponseData(
        documentName="Passport",
        errorCode=error_code,
        mrz=mrz_model,
        ocr=ocr_model,
        status=message,
    )
def _run_paddle_ocr(image_bytes):
    """Run the shared PaddleOCR engine on raw image bytes; return one dict per text line."""
    import tempfile

    # Use a unique temporary file rather than a fixed 'image.jpg' in the
    # working directory. Otherwise concurrent workers could overwrite each
    # other's upload, and the passport image would be left on disk.
    fd, img_path = tempfile.mkstemp(suffix=".jpg")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(image_bytes)
        result = ocr.ocr(img_path, cls=True)
    finally:
        os.remove(img_path)

    lines = []
    for page in result or []:
        # PaddleOCR can return None for a page with no detected text (verify
        # for the installed version); iterating it would raise TypeError.
        if not page:
            continue
        for coordinates, (text, confidence) in page:
            lines.append({
                'coordinates': coordinates,
                'text': text,
                'confidence': confidence,
            })
    return lines


def _parse_llm_json(response_str):
    """Return the first JSON object embedded in an LLM completion.

    The old ``rstrip("</s>")`` removed any trailing '<', '/', 's', '>'
    characters rather than the literal end-of-sequence token. It also failed
    whenever the model wrapped the JSON in prose or code fences. Instead,
    decoding is attempted at each '{' in turn until one parses.

    Raises:
        ValueError: If no JSON object can be decoded from ``response_str``.
    """
    decoder = json.JSONDecoder()
    start = response_str.find("{")
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(response_str, start)
        except json.JSONDecodeError:
            start = response_str.find("{", start + 1)
            continue
        return parsed
    raise ValueError(f"LLM response contains no JSON object: {response_str[:200]!r}")


def _extract_passport_fields(ocr_lines):
    """Ask the LLM to map raw OCR lines onto the passport fields and parse its JSON reply."""
    llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
    llm_chain = LLMChain(prompt=prompt, llm=llm)
    response_str = llm_chain.run(ocr_result=ocr_lines)
    return _parse_llm_json(response_str)


@app.post("/recognize_passport", response_model=ResponseData, status_code=status.HTTP_201_CREATED)
async def recognize_passport(image: UploadFile = File(...)):
    """Passport information extraction from a provided image file.

    Returns the machine-readable zone (MRZ) data and the fields extracted
    from the OCR text. ``errorCode`` in the response tells which parts are
    missing.
    """
    try:
        image_bytes = await image.read()
        if not image_bytes:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Uploaded image is empty",
            )
        mrz = read_mrz(image_bytes)
        ocr_lines = _run_paddle_ocr(image_bytes)
        # With no OCR text the LLM has nothing to work from. Report the OCR
        # part as unavailable instead of asking the model to guess.
        ocr_data = _extract_passport_fields(ocr_lines) if ocr_lines else None
        return create_response_data(mrz.to_dict() if mrz else None, ocr_data)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal server error: {str(e)}"
        ) from e