Spaces:
Paused
Paused
Upload 5 files
Browse files- app.py +43 -0
- arial.ttf +0 -0
- packages.txt +1 -0
- requirements.txt +5 -0
- server.py +176 -0
app.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import requests
|
6 |
+
import socket
|
7 |
+
|
8 |
+
def start_server():
    """Launch the uvicorn inference server in the background.

    The original used os.system(), which blocks until uvicorn exits —
    i.e. forever — so the Streamlit script never got past this call and
    the UI never rendered. subprocess.Popen starts the server without
    blocking the current process.
    """
    import subprocess  # local import keeps the fix self-contained

    # NOTE(review): the repo uploads server.py but the command targets
    # module "inference_server" — confirm the intended module name.
    subprocess.Popen([
        "uvicorn", "inference_server:app",
        "--port", "8080",
        "--host", "0.0.0.0",
        "--workers", "2",
    ])
    st.session_state['server_started'] = True
|
11 |
+
|
12 |
+
def is_port_in_use(port):
    """Return True if something is accepting TCP connections on *port* locally.

    connect_ex returns 0 on a successful connect and an errno otherwise,
    so a zero result means a server is already listening.

    (Removed the redundant function-local ``import socket`` — the module
    already imports socket at the top level.)
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('0.0.0.0', port)) == 0
|
16 |
+
|
17 |
+
def recognize_passport(image_path):
    """POST the image at *image_path* to the local inference server.

    Returns the decoded JSON body of the response.

    The file is now opened in a context manager so the handle is closed
    even if the request raises — the original leaked the open file object.
    """
    with open(image_path, 'rb') as img:
        response = requests.post(
            "http://0.0.0.0:8080/recognize_passport",
            files={'image': img},
        )
    return response.json()
|
21 |
+
|
22 |
+
# --- one-time server bootstrap -------------------------------------------
# Streamlit reruns this whole script on every interaction; the session-state
# flag ensures the inference server is only launched once per session.
st.session_state.setdefault('server_started', False)
if not st.session_state['server_started']:
    start_server()

# --- UI -------------------------------------------------------------------
st.title('Passport Recognition Demo')

image_path = st.file_uploader("Upload Passport Image", type=["jpg", "jpeg", "png"])

if image_path is not None:
    # Echo the upload back to the user, then run recognition.
    st.image(image_path, caption="Uploaded Image.", use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Persist the upload so the recognizer can re-open it from disk.
    with open("temp_image.jpg", "wb") as f:
        f.write(image_path.read())

    passport_info = recognize_passport("temp_image.jpg")

    # Render the structured result as pretty-printed JSON.
    st.markdown(f'## Passport Recognition Results')
    st.write(json.dumps(passport_info, indent=2))
|
arial.ttf
ADDED
Binary file (367 kB). View file
|
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
tesseract-ocr-all
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
-i https://pypi.tuna.tsinghua.edu.cn/simple
paddlepaddle
fastapi
uvicorn
passporteye
langchain
huggingface_hub
streamlit
paddleocr
|
server.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
import os
import tempfile
from typing import Any, Optional, Dict, List

from fastapi import FastAPI, File, UploadFile, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from langchain.chains import LLMChain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from paddleocr import PaddleOCR
from passporteye import read_mrz
from pydantic import BaseModel, Field
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# --- Hugging Face inference configuration ---------------------------------
# NOTE(review): the original called os.getenv without importing os anywhere
# in server.py, which raised NameError at import time; `import os` is now in
# the import block above.
HF_token = os.getenv("apiToken")  # HF API token from the Space's secrets

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
hf_token = HF_token
# Generation parameters forwarded verbatim to InferenceClient.text_generation.
kwargs = {"max_new_tokens": 500, "temperature": 0.1, "top_p": 0.95, "repetition_penalty": 1.0, "do_sample": True}
|
20 |
+
|
21 |
+
class KwArgsModel(BaseModel):
    """Mixin carrying a free-form dict of text-generation keyword arguments."""
    # Extra kwargs forwarded to InferenceClient.text_generation
    # (e.g. max_new_tokens, temperature); defaults to an empty dict.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
|
23 |
+
|
24 |
+
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain LLM wrapper around a Hugging Face InferenceClient.

    Streams generated tokens from ``text_generation`` and joins them into
    a single string.
    """
    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Build the underlying InferenceClient and initialise the model.

        Args:
            model_name: HF model repo id, e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1".
            hf_token: HF API token used to authenticate the client.
            kwargs: generation parameters forwarded to text_generation.
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        # Fix: the original forwarded kwargs=None when the caller omitted it,
        # which fails pydantic validation of the Dict-typed field.
        # NOTE(review): hf_token is forwarded to super().__init__ although no
        # declared field carries it — confirm the base model accepts it.
        super().__init__(
            model_name=model_name,
            hf_token=hf_token,
            kwargs=kwargs if kwargs is not None else {},
            inference_client=inference_client
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None
    ) -> str:
        """Generate text for *prompt*; `stop` sequences are not supported."""
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream chunks and concatenate them; return_full_text=False keeps
        # the prompt itself out of the output.
        response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True, return_full_text=False)
        return ''.join(response_gen)

    @property
    def _llm_type(self) -> str:
        """LangChain identifier for this LLM implementation."""
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        """Parameters LangChain uses for caching/serialisation keys."""
        return {"model_name": self.model_name}
|
55 |
+
|
56 |
+
# FastAPI application serving the passport-recognition endpoint.
app = FastAPI(title="Passport Recognition API")

# Allow browser clients from any origin to call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# demo-grade CORS and not safe for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
65 |
+
|
66 |
+
# English OCR engine; use_angle_cls enables classification of rotated text.
ocr = PaddleOCR(use_angle_cls=True, lang='en')
# Prompt asking the LLM to normalise noisy OCR output into a fixed JSON
# schema (keys mirror the OCRData model below).
# NOTE(review): the prompt contains spelling typos ("catergories",
# "adujustments"); it is a runtime string sent to the model, so it is
# left byte-identical here — consider cleaning it up separately.
template = """below is poorly read ocr result of a passport.
OCR Result:
{ocr_result}

Fill the below catergories using the OCR Results. you can correct spellings and make other adujustments. Dates should be in 01-JAN-2000 format.

"countryName": "",
"dateOfBirth": "",
"dateOfExpiry": "",
"dateOfIssue": "",
"documentNumber": "",
"givenNames": "",
"name": "",
"surname": "",
"mrz": ""

json output:
"""
# Single-variable template: the flattened OCR line list is interpolated.
prompt = PromptTemplate(template=template, input_variables=["ocr_result"])
|
86 |
+
|
87 |
+
class MRZData(BaseModel):
    """Machine-readable-zone fields for a parsed passport.

    Shape matches the dict produced by passporteye's MRZ.to_dict()
    (see create_response_data / the endpoint below), including the raw
    check digits and per-field validation flags.
    """
    date_of_birth: str
    expiration_date: str
    # Document type code (shadows the builtin name; the key is dictated
    # by the upstream dict schema, so it cannot be renamed).
    type: str
    number: str
    names: str
    country: str
    # Raw check digits read from the MRZ lines.
    check_number: str
    check_date_of_birth: str
    check_expiration_date: str
    check_composite: str
    check_personal_number: str
    # Whether each check digit validated against its field.
    valid_number: bool
    valid_date_of_birth: bool
    valid_expiration_date: bool
    valid_composite: bool
    valid_personal_number: bool
    method: str
|
105 |
+
|
106 |
+
class OCRData(BaseModel):
    """Passport fields reconstructed by the LLM from noisy OCR text.

    Key names (camelCase) match the JSON schema requested in the prompt
    template above, so the LLM's JSON output can be splatted directly
    into this model.
    """
    countryName: str
    dateOfBirth: str
    dateOfExpiry: str
    dateOfIssue: str
    documentNumber: str
    givenNames: str
    name: str
    surname: str
    mrz: str
|
116 |
+
|
117 |
+
class ResponseData(BaseModel):
    """Response envelope returned by the /recognize_passport endpoint."""
    documentName: str  # always "Passport" (set in create_response_data)
    errorCode: int     # 0 on success
    mrz: MRZData       # fields parsed from the machine-readable zone
    ocr: OCRData       # fields reconstructed from full-page OCR + LLM
    status: str        # "ok" on success
|
123 |
+
|
124 |
+
|
125 |
+
def create_response_data(mrz, ocr_data):
    """Assemble the API response envelope from raw MRZ and OCR dicts.

    Args:
        mrz: dict of MRZ fields (shape of MRZData).
        ocr_data: dict of OCR fields (shape of OCRData).
    """
    mrz_model = MRZData(**mrz)
    ocr_model = OCRData(**ocr_data)
    return ResponseData(
        documentName="Passport",
        errorCode=0,
        mrz=mrz_model,
        ocr=ocr_model,
        status="ok",
    )
|
133 |
+
|
134 |
+
|
135 |
+
@app.post("/recognize_passport", response_model=ResponseData, status_code=status.HTTP_201_CREATED)
async def recognize_passport(image: UploadFile = File(...)):
    """Passport information extraction from a provided image file.

    Pipeline: save upload -> MRZ parse (passporteye) -> full-page OCR
    (PaddleOCR) -> LLM cleanup of the OCR text into the OCRData schema.

    Raises:
        HTTPException 422 if no machine-readable zone is found.
        HTTPException 500 for any other processing failure.
    """
    img_path = None
    try:
        image_bytes = await image.read()

        # Persist the upload to a unique temp file first. The original
        # wrote a fixed 'image.jpg' (races under `--workers 2`) and passed
        # raw bytes to read_mrz, which expects a path or file-like object.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp.write(image_bytes)
            img_path = tmp.name

        mrz = read_mrz(img_path)
        if mrz is None:
            # The original crashed later on mrz.to_dict() (AttributeError
            # -> opaque 500); fail fast with an actionable status instead.
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No machine-readable zone found in the image."
            )

        result = ocr.ocr(img_path, cls=True)
        # Flatten PaddleOCR's nested [page][line] output into plain dicts.
        json_result = []
        for page in result:
            for coordinates, (text, confidence) in page:
                json_result.append({
                    'coordinates': coordinates,
                    'text': text,
                    'confidence': confidence
                })

        llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        response_str = llm_chain.run(ocr_result=json_result)
        # Strip a trailing end-of-sequence token. The original used
        # rstrip("</s>"), which strips any run of the characters <,/,s,>
        # rather than the literal suffix.
        if response_str.endswith("</s>"):
            response_str = response_str[:-len("</s>")]

        ocr_data = json.loads(response_str)

        return create_response_data(mrz.to_dict(), ocr_data)

    except HTTPException as e:
        raise e

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal server error: {str(e)}"
        ) from e

    finally:
        # Best-effort cleanup of the per-request temp image.
        if img_path is not None:
            try:
                os.unlink(img_path)
            except OSError:
                pass
|