initial commit

Files changed:
- .gitattributes +0 -35
- .gitignore +2 -0
- Dockerfile +16 -0
- __pycache__/app.cpython-312.pyc +0 -0
- app.py +226 -0
- layoutlmv3FineTuning/Copy of annotate_image.py +51 -0
- layoutlmv3FineTuning/Copy of run_inference.py +32 -0
- layoutlmv3FineTuning/Layoutlm_inference/__init__.py +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/__init__.cpython-312.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/__init__.cpython-39.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/annotate_image.cpython-312.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/annotate_image.cpython-39.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/inference_handler.cpython-312.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/inference_handler.cpython-39.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/model_base_path.cpython-312.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/model_base_path.cpython-39.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/ocr.cpython-312.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/ocr.cpython-39.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/utils.cpython-312.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/__pycache__/utils.cpython-39.pyc +0 -0
- layoutlmv3FineTuning/Layoutlm_inference/annotate_image.py +54 -0
- layoutlmv3FineTuning/Layoutlm_inference/inference_handler.py +268 -0
- layoutlmv3FineTuning/Layoutlm_inference/model_base_path.py +2 -0
- layoutlmv3FineTuning/Layoutlm_inference/ocr.py +144 -0
- layoutlmv3FineTuning/Layoutlm_inference/utils.py +68 -0
- layoutlmv3FineTuning/README.md +3 -0
- layoutlmv3FineTuning/inference_handler_modified.py +213 -0
- layoutlmv3FineTuning/preprocess.py +163 -0
- layoutlmv3FineTuning/run_inference.py +31 -0
- layoutlmv3FineTuning/run_inferenceM.py +32 -0
- multiple_request.py +56 -0
- requirements.txt +9 -0
- sample.py +22 -0
- titanium-scope-436311-t3-966373f5aa2f.json +13 -0
- uploads/aadhar/test_one.jpg +0 -0
- uploads/aadhar/test_two.jpg +0 -0
- uploads/cheque/0f81678a.jpeg +0 -0
- uploads/gst/0a52fbcb_page3_image_0.jpg +0 -0
- uploads/pan/6ea33087.jpeg +0 -0
.gitattributes
DELETED
@@ -1,35 +0,0 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,2 @@
.env
dependencies/
Dockerfile
ADDED
@@ -0,0 +1,16 @@
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.9

RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
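The CMD line above starts the FastAPI app with the uvicorn CLI on port 7860, the port Hugging Face Spaces expects. A minimal sketch of the equivalent programmatic start (the container itself uses the CLI form; this only restates that command in Python):

# Sketch: programmatic equivalent of
# CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
import uvicorn

if __name__ == "__main__":
    # Serve the FastAPI instance defined in app.py on the Spaces default port.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)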
__pycache__/app.cpython-312.pyc
ADDED
Binary file (6.58 kB).
app.py
ADDED
@@ -0,0 +1,226 @@
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict
import os
import shutil
import logging

import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification

from dotenv import load_dotenv
import os

# Load .env file
load_dotenv()

# Access variables
dummy_key = os.getenv("dummy_key")
HUGGINGFACE_AUTH_TOKEN = dummy_key

# Hugging Face model and token
aadhar_model = "AuditEdge/doc_ocr_a"  # Replace with your fine-tuned model if applicable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the processor (tokenizer + image processor)
processor_aadhar = LayoutLMv3Processor.from_pretrained(
    aadhar_model,
    use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
aadhar_model = LayoutLMv3ForTokenClassification.from_pretrained(
    aadhar_model,
    use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
aadhar_model = aadhar_model.to(device)

# pan model
pan_model = "AuditEdge/doc_ocr_p"  # Replace with your fine-tuned model if applicable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the processor (tokenizer + image processor)
processor_pan = LayoutLMv3Processor.from_pretrained(
    pan_model,
    use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
pan_model = LayoutLMv3ForTokenClassification.from_pretrained(
    pan_model,
    use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
pan_model = pan_model.to(device)

#
# gst model
gst_model = "AuditEdge/doc_ocr_new_g"  # Replace with your fine-tuned model if applicable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the processor (tokenizer + image processor)
processor_gst = LayoutLMv3Processor.from_pretrained(
    gst_model,
    use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
gst_model = LayoutLMv3ForTokenClassification.from_pretrained(
    gst_model,
    use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
gst_model = gst_model.to(device)

# cheque model
cheque_model = "AuditEdge/doc_ocr_new_c"  # Replace with your fine-tuned model if applicable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the processor (tokenizer + image processor)
processor_cheque = LayoutLMv3Processor.from_pretrained(
    cheque_model,
    use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
cheque_model = LayoutLMv3ForTokenClassification.from_pretrained(
    cheque_model,
    use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
cheque_model = cheque_model.to(device)

# Verify model and processor are loaded
print("Model and processor loaded successfully!")
print(f"Model is on device: {next(aadhar_model.parameters()).device}")

# Import inference modules
from layoutlmv3FineTuning.Layoutlm_inference.ocr import prepare_batch_for_inference
from layoutlmv3FineTuning.Layoutlm_inference.inference_handler import handle

# Create FastAPI instance
app = FastAPI(debug=True)

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configure directories
UPLOAD_FOLDER = './uploads/'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)  # Ensure the main upload folder exists

UPLOAD_DIRS = {
    "aadhar_file": "uploads/aadhar/",
    "pan_file": "uploads/pan/",
    "cheque_file": "uploads/cheque/",
    "gst_file": "uploads/gst/",
}

# Ensure individual directories exist
for dir_path in UPLOAD_DIRS.values():
    os.makedirs(dir_path, exist_ok=True)

# Logger configuration
logging.basicConfig(level=logging.INFO)

# Perform Inference
def perform_inference(file_paths: Dict[str, str]):
    # Dictionary to map document types to their respective model directories
    model_dirs = {
        "aadhar_file": aadhar_model,
        "pan_file": pan_model,
        "cheque_file": cheque_model,
        "gst_file": gst_model,
    }

    # Dictionary to store results for each document type
    inference_results = {}

    # Loop through the file paths and perform inference
    for doc_type, file_path in file_paths.items():
        if doc_type in model_dirs:
            print(f"Processing {doc_type} using model at {model_dirs[doc_type]}")

            # Prepare batch for inference
            images_path = [file_path]
            inference_batch = prepare_batch_for_inference(images_path)

            # Prepare context for the specific document type
            # context = {"model_dir": model_dirs[doc_type]}
            # context = aadhar_model
            if doc_type == "aadhar_file":
                context = aadhar_model
                processor = processor_aadhar
                name = "aadhar"

            if doc_type == "pan_file":
                context = pan_model
                processor = processor_pan
                name = "pan"

            if doc_type == "gst_file":
                context = gst_model
                processor = processor_gst
                name = "gst"

            if doc_type == "cheque_file":
                context = cheque_model
                processor = processor_cheque
                name = "cheque"

            # Perform inference (replace `handle` with your actual function)
            result = handle(inference_batch, context, processor, name)

            # Store the result
            inference_results[doc_type] = result
        else:
            print(f"Model directory not found for {doc_type}. Skipping.")

    return inference_results

# Routes
@app.get("/")
def greet_json():
    return {"Hello": "World!"}

@app.post("/api/aadhar_ocr")
async def aadhar_ocr(
    aadhar_file: UploadFile = File(None),
    pan_file: UploadFile = File(None),
    cheque_file: UploadFile = File(None),
    gst_file: UploadFile = File(None),
):
    try:
        # Handle file uploads
        file_paths = {}
        for file_type, folder in UPLOAD_DIRS.items():
            file = locals()[file_type]  # Dynamically access the file arguments
            if file:
                # Save the file in the respective directory
                file_path = os.path.join(folder, file.filename)
                with open(file_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
                file_paths[file_type] = file_path

        # Log received files
        logging.info(f"Received files: {list(file_paths.keys())}")
        print("file_paths", file_paths)
        import sys
        # sys.exit()

        # Perform inference
        result = perform_inference(file_paths)

        return {"status": "success", "result": result}

    except Exception as e:
        logging.error(f"Error processing files: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
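For reference, a minimal client sketch for the /api/aadhar_ocr endpoint above, assuming the service is running locally on port 7860. The form field names come from the UPLOAD_DIRS keys in app.py; the sample file name is a placeholder:

# Sketch of a client call to /api/aadhar_ocr; URL and file path are placeholders.
import requests

url = "http://localhost:7860/api/aadhar_ocr"
with open("uploads/pan/sample_pan.jpeg", "rb") as f:  # hypothetical sample image
    files = {"pan_file": ("sample_pan.jpeg", f, "image/jpeg")}
    response = requests.post(url, files=files)

# Expected shape on success: {"status": "success", "result": {"pan_file": {"pan_num": "..."}}}
print(response.json())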
layoutlmv3FineTuning/Copy of annotate_image.py
ADDED
@@ -0,0 +1,51 @@
import os
from PIL import Image, ImageDraw, ImageFont
from .utils import image_label_2_color


def get_flattened_output(docs):
    flattened_output = []
    annotation_key = 'output'
    for doc in docs:
        flattened_output_item = {annotation_key: []}
        doc_annotation = doc[annotation_key]
        for i, span in enumerate(doc_annotation):
            if len(span['words']) > 1:
                for span_chunk in span['words']:
                    flattened_output_item[annotation_key].append(
                        {
                            'label': span['label'],
                            'text': span_chunk['text'],
                            'words': [span_chunk]
                        }
                    )
            else:
                flattened_output_item[annotation_key].append(span)
        flattened_output.append(flattened_output_item)
    return flattened_output


def annotate_image(image_path, annotation_object):
    img = None
    image = Image.open(image_path).convert('RGBA')
    tmp = image.copy()
    label2color = image_label_2_color(annotation_object)
    overlay = Image.new('RGBA', tmp.size, (0, 0, 0)+(0,))
    draw = ImageDraw.Draw(overlay)
    font = ImageFont.load_default()

    predictions = [span['label'] for span in annotation_object['output']]
    boxes = [span['words'][0]['box'] for span in annotation_object['output']]
    for prediction, box in zip(predictions, boxes):
        draw.rectangle(box, outline=label2color[prediction],
                       width=3, fill=label2color[prediction]+(int(255*0.33),))
        draw.text((box[0] + 10, box[1] - 10), text=prediction,
                  fill=label2color[prediction], font=font)

    img = Image.alpha_composite(tmp, overlay)
    img = img.convert("RGB")

    image_name = os.path.basename(image_path)
    image_name = image_name[:image_name.find('.')]
    save_path = os.path.join('/content', f'{image_name}_annotated.jpg')
    img.save(save_path)
layoutlmv3FineTuning/Copy of run_inference.py
ADDED
@@ -0,0 +1,32 @@
import argparse
from asyncio.log import logger
from Layoutlm_inference.ocr import prepare_batch_for_inference
from Layoutlm_inference.inference_handler import handle
import logging
import os


if __name__ == "__main__":
    # try:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--image_path", type=str, required=True)  # single image path
    args = parser.parse_args()

    # Expecting a single image file
    image_path = args.image_path

    # Ensure the file exists before processing
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"The provided image path does not exist: {image_path}")

    # Prepare batch for a single image
    inference_batch = prepare_batch_for_inference([image_path])  # pass as a list
    context = {"model_dir": args.model_path}

    # Handle the inference
    handle(inference_batch, context)
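A sketch of invoking this script from Python. The script imports Layoutlm_inference.* as a sibling package, so the working directory is assumed to be layoutlmv3FineTuning; the model path is a placeholder, and the image path reuses a file from this commit's uploads folder:

# Hypothetical invocation of "Copy of run_inference.py"; paths are placeholders.
import subprocess

subprocess.run(
    [
        "python", "Copy of run_inference.py",
        "--model_path", "path/to/fine_tuned_model",
        "--image_path", "../uploads/aadhar/test_one.jpg",
    ],
    cwd="layoutlmv3FineTuning",  # so the relative package imports resolve
    check=True,                  # raise if the script exits with a non-zero status
)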
layoutlmv3FineTuning/Layoutlm_inference/__init__.py
ADDED
File without changes

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (184 Bytes).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (147 Bytes).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/annotate_image.cpython-312.pyc
ADDED
Binary file (2.58 kB).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/annotate_image.cpython-39.pyc
ADDED
Binary file (1.73 kB).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/inference_handler.cpython-312.pyc
ADDED
Binary file (12.3 kB).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/inference_handler.cpython-39.pyc
ADDED
Binary file (7.22 kB).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/model_base_path.cpython-312.pyc
ADDED
Binary file (314 Bytes).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/model_base_path.cpython-39.pyc
ADDED
Binary file (267 Bytes).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/ocr.cpython-312.pyc
ADDED
Binary file (5.27 kB).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/ocr.cpython-39.pyc
ADDED
Binary file (3.52 kB).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (3.19 kB).

layoutlmv3FineTuning/Layoutlm_inference/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (2.52 kB).
layoutlmv3FineTuning/Layoutlm_inference/annotate_image.py
ADDED
@@ -0,0 +1,54 @@
import os
from PIL import Image, ImageDraw, ImageFont
from .utils import image_label_2_color


def get_flattened_output(docs):
    flattened_output = []
    annotation_key = 'output'
    for doc in docs:
        flattened_output_item = {annotation_key: []}
        doc_annotation = doc[annotation_key]
        for i, span in enumerate(doc_annotation):
            if len(span['words']) > 1:
                for span_chunk in span['words']:
                    flattened_output_item[annotation_key].append(
                        {
                            'label': span['label'],
                            'text': span_chunk['text'],
                            'words': [span_chunk]
                        }
                    )
            else:
                flattened_output_item[annotation_key].append(span)
        flattened_output.append(flattened_output_item)
    return flattened_output


def annotate_image(image_path, annotation_object):
    print("image_path", image_path)
    img = None
    image = Image.open(image_path).convert('RGBA')
    tmp = image.copy()
    label2color = image_label_2_color(annotation_object)
    overlay = Image.new('RGBA', tmp.size, (0, 0, 0)+(0,))
    draw = ImageDraw.Draw(overlay)
    font = ImageFont.load_default()

    predictions = [span['label'] for span in annotation_object['output']]
    boxes = [span['words'][0]['box'] for span in annotation_object['output']]
    for prediction, box in zip(predictions, boxes):
        print("prediction", prediction)
        print("box", box)
        draw.rectangle(box, outline=label2color[prediction],
                       width=3, fill=label2color[prediction]+(int(255*0.33),))
        draw.text((box[0] + 10, box[1] - 10), text=prediction,
                  fill=label2color[prediction], font=font)

    # img = Image.alpha_composite(tmp, overlay)
    # img = img.convert("RGB")

    # image_name = os.path.basename(image_path)
    # image_name = image_name[:image_name.find('.')]
    # save_path = os.path.join('/home/ec2-user/sample_project/inferred_images', f'{image_name}_annotated_1.jpg')
    # img.save(save_path)
layoutlmv3FineTuning/Layoutlm_inference/inference_handler.py
ADDED
@@ -0,0 +1,268 @@
from .utils import load_model, load_processor, normalize_box, compare_boxes, adjacent
from .model_base_path import LAYOUTLMV2_BASE_PATH, LAYOUTLMV3_BASE_PATH
from .annotate_image import get_flattened_output, annotate_image
from PIL import Image, ImageDraw, ImageFont
import logging
import torch
import json
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


logger = logging.getLogger(__name__)

class ModelHandler(object):
    """
    A base Model handler implementation.
    """

    def __init__(self):
        # self.model = None
        # self.model_dir = None
        # self.device = 'cpu'
        # self.error = None
        # self._context = None
        # self._batch_size = 0
        self.initialized = False
        self._raw_input_data = None
        self._processed_data = None
        self._images_size = None

    def initialize(self, context, preprocessor, name):
        """
        Initialize model. This will be called during model loading time
        :param context: Initial context contains model server system properties.
        :return:
        """
        logger.info("Loading transformer model")

        # self._context = context
        # properties = self._context
        # self._batch_size = properties["batch_size"] or 1
        # self.model_dir = properties.get("model_dir")
        self.name = name
        self.model = context
        self.preprocessor = preprocessor
        self.initialized = True

    def preprocess(self, batch):
        """
        Transform raw input into model input data.
        :param batch: list of raw requests, should match batch size
        :return: list of preprocessed model input data
        """
        # Take the input data and pre-process it make it inference ready
        # assert self._batch_size == len(batch), "Invalid input batch size: {}".format(len(batch))
        inference_dict = batch

        print("inference_dict", inference_dict)
        self._raw_input_data = inference_dict
        # model_name_or_path = None
        # if 'v2' in self.model.config.architectures[0]:
        #     model_name_or_path = LAYOUTLMV2_BASE_PATH
        # elif 'v3' in self.model.config.architectures[0]:
        #     model_name_or_path = LAYOUTLMV3_BASE_PATH
        # else:
        #     raise ValueError('invalid model architecture, please make sure the model is either Layoutlmv2 or Layoutlmv3')
        # processor = load_processor(model_name_or_path)
        processor = self.preprocessor

        images = [Image.open(path).convert("RGB")
                  for path in inference_dict['image_path']]
        self._images_size = [img.size for img in images]
        words = inference_dict['words']
        boxes = [[normalize_box(box, images[i].size[0], images[i].size[1])
                  for box in doc] for i, doc in enumerate(inference_dict['bboxes'])]
        encoded_inputs = processor(
            images, words, boxes=boxes, return_tensors="pt", padding="max_length", truncation=True)
        self._processed_data = encoded_inputs
        encoded_inputs = {key: val.to(device) for key, val in encoded_inputs.items()}
        print("encoded_inputs", encoded_inputs)

        return encoded_inputs

    def load(self, model_dir):
        """The load handler is responsible for loading the hunggingface transformer model.
        Returns:
            hf_pipeline (Pipeline): A Hugging Face Transformer pipeline.
        """
        # TODO model dir should be microsoft/layoutlmv2-base-uncased
        model = load_model(model_dir)
        return model

    def inference(self, model_input):
        """
        Internal inference methods
        :param model_input: transformed model input data
        :return: list of inference output in NDArray
        """
        # TODO load the model state_dict before running the inference
        # Do some inference call to engine here and return output
        with torch.no_grad():
            inference_outputs = self.model(**model_input)
            predictions = inference_outputs.logits.argmax(-1).tolist()
        print("these are predictions", predictions)
        results = []
        for i in range(len(predictions)):
            tmp = dict()
            tmp[f'output_{i}'] = predictions[i]
            results.append(tmp)

        return [results]

    def postprocess(self, inference_output):
        print("self._raw_input_data['words']", self._raw_input_data['words'])
        print("inference_output", inference_output)

        docs = []
        k = 0
        for page, doc_words in enumerate(self._raw_input_data['words']):
            print(page, doc_words)
            doc_list = []
            width, height = self._images_size[page]
            for i, doc_word in enumerate(doc_words, start=0):
                word_tagging = None
                word_labels = []
                word = dict()
                word['id'] = k
                k += 1
                word['text'] = doc_word
                word['pageNum'] = page + 1
                word['box'] = self._raw_input_data['bboxes'][page][i]
                _normalized_box = normalize_box(
                    self._raw_input_data['bboxes'][page][i], width, height)
                for j, box in enumerate(self._processed_data['bbox'].tolist()[page]):
                    if compare_boxes(box, _normalized_box):
                        if self.model.config.id2label[inference_output[0][page][f'output_{page}'][j]] != 'O':
                            word_labels.append(
                                self.model.config.id2label[inference_output[0][page][f'output_{page}'][j]][2:])
                        else:
                            word_labels.append('other')
                if word_labels != []:
                    word_tagging = word_labels[0] if word_labels[0] != 'other' else word_labels[-1]
                else:
                    word_tagging = 'other'
                word['label'] = word_tagging
                word['pageSize'] = {'width': width, 'height': height}
                if word['label'] != 'other':
                    doc_list.append(word)
            spans = []
            def adjacents(entity): return [
                adj for adj in doc_list if adjacent(entity, adj)]
            output_test_tmp = doc_list[:]
            for entity in doc_list:
                if adjacents(entity) == []:
                    spans.append([entity])
                    output_test_tmp.remove(entity)

            while output_test_tmp != []:
                span = [output_test_tmp[0]]
                output_test_tmp = output_test_tmp[1:]
                while output_test_tmp != [] and adjacent(span[-1], output_test_tmp[0]):
                    span.append(output_test_tmp[0])
                    output_test_tmp.remove(output_test_tmp[0])
                spans.append(span)

            output_spans = []
            for span in spans:
                if len(span) == 1:
                    output_span = {"text": span[0]['text'],
                                   "label": span[0]['label'],
                                   "words": [{
                                       'id': span[0]['id'],
                                       'box': span[0]['box'],
                                       'text': span[0]['text']
                                   }],
                                   }
                else:
                    output_span = {"text": ' '.join([entity['text'] for entity in span]),
                                   "label": span[0]['label'],
                                   "words": [{
                                       'id': entity['id'],
                                       'box': entity['box'],
                                       'text': entity['text']
                                   } for entity in span]
                                   }
                output_spans.append(output_span)
            docs.append({f'output': output_spans})
        return [json.dumps(docs, ensure_ascii=False)]

    def handle(self, data, context):
        """
        Call preprocess, inference and post-process functions
        :param data: input data
        :param context: mms context
        """
        # print("\nmodel_input\n",data)
        print("context", context)

        model_input = self.preprocess(data)
        print("this is model input", model_input)
        model_out = self.inference(model_input)
        print("\nmodel_output\n", model_out)
        inference_out = self.postprocess(model_out)[0]

        print("\nprocessed output\n", inference_out)

        # with open('LayoutlMV3InferenceOutput.json', 'w') as inf_out:
        #     inf_out.write(inference_out)
        inference_out_list = json.loads(inference_out)
        flattened_output_list = get_flattened_output(inference_out_list)
        print("flattened_output_list", flattened_output_list)

        if self.name == "cheque":
            acc_num = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'AN')
            IFSC = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'IFSC')

            print("entered cheque\n\n", flattened_output_list, "\n\n")
            result = {"acc_num": acc_num,
                      "IFSC": IFSC}
        if self.name == "aadhar":
            aadhar_num = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'AN')
            print("entered aadhar\n\n", flattened_output_list, "\n\n")
            # IFSC = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'IFSC')
            result = {"aadhar_num": aadhar_num}

        if self.name == "pan":
            pan_num = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'PAN_VALUE')
            print("entered pan\n\n", flattened_output_list, "\n\n")
            # IFSC = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'IFSC')
            result = {"pan_num": pan_num}
        if self.name == "gst":
            gstin_num = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'GSTIN')
            print("entered gst\n\n", flattened_output_list, "\n\n")
            # IFSC = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'IFSC')
            result = {"gstin_num": gstin_num}

        # if
        # an_tokens = "".join(item['text'] for item in flattened_output_list[0]['output'] if item['label'] == 'AN')
        # PAN_VALUE
        # AN
        # IFSC

        # print(f"Concatenated AN tokens: {an_tokens}")

        # print("this is flattened output",flattened_output_list)
        for i, flattened_output in enumerate(flattened_output_list):
            annotate_image(data['image_path'][i], flattened_output)

        return result


_service = ModelHandler()


def handle(data, context, processor, name):
    # if not _service.initialized:
    _service.initialize(context, processor, name)

    # if data is None:
    #     return None

    return _service.handle(data, context)
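The span-building step in postprocess groups consecutive words that share a label (via adjacent from utils.py) before handle concatenates their text per label. A simplified, self-contained sketch of that grouping on toy word dicts (the ids, labels, and texts are invented; the real handler also peels off isolated entities first):

# Toy illustration of the adjacency-based span grouping used in postprocess().
def adjacent(w1, w2):
    # same logic as utils.adjacent: same label and consecutive ids
    return w1['label'] == w2['label'] and abs(w1['id'] - w2['id']) == 1

words = [
    {'id': 0, 'label': 'AN', 'text': '1234'},
    {'id': 1, 'label': 'AN', 'text': '5678'},
    {'id': 2, 'label': 'IFSC', 'text': 'ABCD0001234'},
]

spans, remaining = [], words[:]
while remaining:
    span = [remaining.pop(0)]
    while remaining and adjacent(span[-1], remaining[0]):
        span.append(remaining.pop(0))
    spans.append(span)

for span in spans:
    print(span[0]['label'], ''.join(w['text'] for w in span))
# AN 12345678
# IFSC ABCD0001234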
layoutlmv3FineTuning/Layoutlm_inference/model_base_path.py
ADDED
@@ -0,0 +1,2 @@
LAYOUTLMV2_BASE_PATH = "microsoft/layoutlmv2-base-uncased"
LAYOUTLMV3_BASE_PATH = "microsoft/layoutlmv3-base"
layoutlmv3FineTuning/Layoutlm_inference/ocr.py
ADDED
@@ -0,0 +1,144 @@
import os
import pandas as pd

import os
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./titanium-scope-436311-t3-966373f5aa2f.json"


def run_tesseract_on_image(image_path):  # -> tsv output path
    print("image_path", image_path)
    image_name = os.path.basename(image_path)
    image_name = image_name[:image_name.find('.')]
    error_code = os.system(f'''
    tesseract "{image_path}" "/content/{image_name}" -l eng tsv
    ''')
    if not error_code:
        return f"/content/{image_name}.tsv"
    else:
        raise ValueError('Tesseract OCR Error please verify image format PNG,JPG,JPEG')


def clean_tesseract_output(tsv_output_path):
    print("tsv_output_path", tsv_output_path)
    ocr_df = pd.read_csv(tsv_output_path, sep='\t')
    ocr_df = ocr_df.dropna()
    ocr_df = ocr_df.drop(ocr_df[ocr_df.text.str.strip() == ''].index)
    text_output = ' '.join(ocr_df.text.tolist())
    words = []
    for index, row in ocr_df.iterrows():
        word = {}
        origin_box = [row['left'], row['top'], row['left'] +
                      row['width'], row['top']+row['height']]
        word['word_text'] = row['text']
        word['word_box'] = origin_box
        words.append(word)
    return words


def detect_text(path):
    print("this is path:", path)

    """Detects text in the file."""
    from google.cloud import vision
    client = vision.ImageAnnotatorClient()
    with open(path, "rb") as image_file:
        content = image_file.read()
    image = vision.Image(content=content)
    response = client.text_detection(image=image)
    texts = response.text_annotations
    print("Texts:")
    list_of_dict = []
    for text in texts[1:]:
        data_dic = {}
        print(f'\n"{text.description}"')
        data_dic["word_text"] = text.description

        vertices_list = [[int(vertex.x), int(vertex.y)] for vertex in text.bounding_poly.vertices]
        print("vertices_list", vertices_list)

        coords = vertices_list

        sorted_coords = sorted(coords, key=lambda coord: (coord[0] + coord[1]))

        # Top-left is the first in the sorted list (smallest sum of x, y)
        top_left = sorted_coords[0]

        # Bottom-right is the last in the sorted list (largest sum of x, y)
        bottom_right = sorted_coords[-1]

        ls = []
        ls.append(top_left[0])
        ls.append(top_left[1])
        ls.append(bottom_right[0])
        ls.append(bottom_right[1])

        # print(ls)

        # ls = []

        # ls.append(vertices_list[0][0])
        # ls.append(vertices_list[0][1])
        # ls.append(vertices_list[2][0])
        # ls.append(vertices_list[2][1])

        data_dic["word_box"] = ls

        list_of_dict.append(data_dic)

    if response.error.message:
        raise Exception(
            "{}\nFor more info on error messages, check: "
            "https://cloud.google.com/apis/design/errors".format(response.error.message)
        )

    return list_of_dict


def prepare_batch_for_inference(image_paths):
    # tesseract_outputs is a list of paths
    inference_batch = dict()
    # tesseract_outputs = [run_tesseract_on_image(
    #     image_path) for image_path in image_paths]

    # tesseract_outputs = []
    # for image_path in image_paths:

    #     output = run_tesseract_on_image(image_path)
    #     tesseract_outputs.append(output)

    # clean_outputs is a list of lists
    # clean_outputs = [clean_tesseract_output(
    #     tsv_path) for tsv_path in tesseract_outputs]

    # clean_outputs = []
    # for tsv_path in tesseract_outputs:
    #     output = clean_tesseract_output(tsv_path)
    #     clean_outputs.append(output)

    clean_outputs = []
    for image_path in image_paths:

        output = detect_text(image_path)
        clean_outputs.append(output)

    print("clean_outputs", clean_outputs)

    word_lists = [[word['word_text'] for word in clean_output]
                  for clean_output in clean_outputs]
    boxes_lists = [[word['word_box'] for word in clean_output]
                   for clean_output in clean_outputs]
    inference_batch = {
        "image_path": image_paths,
        "bboxes": boxes_lists,
        "words": word_lists
    }
    return inference_batch
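detect_text reduces each word's four bounding-poly vertices to an axis-aligned [x1, y1, x2, y2] box by sorting the vertices by x + y and taking the first and last. A small worked example with made-up vertices:

# Worked example of the vertex-to-box reduction in detect_text(); vertices are invented.
vertices = [[100, 50], [220, 50], [220, 90], [100, 90]]  # clockwise from top-left

sorted_coords = sorted(vertices, key=lambda c: c[0] + c[1])
top_left, bottom_right = sorted_coords[0], sorted_coords[-1]
word_box = [top_left[0], top_left[1], bottom_right[0], bottom_right[1]]

print(word_box)  # [100, 50, 220, 90]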
layoutlmv3FineTuning/Layoutlm_inference/utils.py
ADDED
@@ -0,0 +1,68 @@
import numpy as np
from transformers import AutoModelForTokenClassification, AutoProcessor

from dotenv import load_dotenv
import os

# Load .env file
load_dotenv()

# Access variables
dummy_key = os.getenv("dummy_key")
# secret_key = os.getenv("SECRET_KEY")
# debug_mode = os.getenv("DEBUG")

# print(f"Database URL: {database_url}")
# print(f"Secret Key: {secret_key}")
# print(f"Debug Mode: {debug_mode}")


def normalize_box(bbox, width, height):
    return [
        int(bbox[0]*(1000/width)),
        int(bbox[1]*(1000/height)),
        int(bbox[2]*(1000/width)),
        int(bbox[3]*(1000/height)),
    ]

def compare_boxes(b1, b2):
    b1 = np.array([c for c in b1])
    b2 = np.array([c for c in b2])
    equal = np.array_equal(b1, b2)
    return equal

def unnormalize_box(bbox, width, height):
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]

def adjacent(w1, w2):
    if w1['label'] == w2['label'] and abs(w1['id'] - w2['id']) == 1:
        return True
    return False

def random_color():
    return np.random.randint(0, 255, 3)

def image_label_2_color(annotation):
    if 'output' in annotation.keys():
        image_labels = set([span['label'] for span in annotation['output']])
        label2color = {f'{label}': (random_color()[0], random_color()[
            1], random_color()[2]) for label in image_labels}
        return label2color
    else:
        raise ValueError('please use "output" as annotation key')

def load_model(model_path):
    model = AutoModelForTokenClassification.from_pretrained(model_path, use_auth_token=dummy_key)
    return model

def load_processor(model_name_or_path):
    processor = AutoProcessor.from_pretrained(
        model_name_or_path, apply_ocr=False, use_auth_token=dummy_key)
    return processor
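normalize_box rescales pixel coordinates to the 0-1000 range LayoutLM expects, and unnormalize_box inverts it. A quick worked example on an invented 2000x1000 image:

# Worked example of normalize_box; the image size and box are invented.
def normalize_box(bbox, width, height):
    return [int(bbox[0] * (1000 / width)), int(bbox[1] * (1000 / height)),
            int(bbox[2] * (1000 / width)), int(bbox[3] * (1000 / height))]

box = [500, 250, 1500, 750]            # pixel coordinates on a 2000x1000 image
print(normalize_box(box, 2000, 1000))  # [250, 250, 750, 750]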
layoutlmv3FineTuning/README.md
ADDED
@@ -0,0 +1,3 @@
# layoutlmFineTuning
This repo aims to train a LayoutLMv3 model on a __ubiai__ OCR-annotated dataset using the provided preprocess and train scripts, and then test the model via the inference script.
* Note that the provided inference module supports both LayoutLMv3 and LayoutLMv2 models
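A sketch of the inference flow the README describes, wired the way app.py in this commit calls it (run from the repository root; the cheque model id and the image path are taken from this commit, and loading the AuditEdge models may require an auth token as in app.py):

# Sketch of the end-to-end inference flow; model access is assumed, not guaranteed.
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from layoutlmv3FineTuning.Layoutlm_inference.ocr import prepare_batch_for_inference
from layoutlmv3FineTuning.Layoutlm_inference.inference_handler import handle

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# use_auth_token may be required here, as in app.py
processor = LayoutLMv3Processor.from_pretrained("AuditEdge/doc_ocr_new_c")
model = LayoutLMv3ForTokenClassification.from_pretrained("AuditEdge/doc_ocr_new_c").to(device)

batch = prepare_batch_for_inference(["uploads/cheque/0f81678a.jpeg"])
result = handle(batch, model, processor, "cheque")
print(result)  # expected shape: {"acc_num": "...", "IFSC": "..."}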
layoutlmv3FineTuning/inference_handler_modified.py
ADDED
@@ -0,0 +1,213 @@
from .utils import load_model, load_processor, normalize_box, compare_boxes, adjacent
from .model_base_path import LAYOUTLMV2_BASE_PATH, LAYOUTLMV3_BASE_PATH
from .annotate_image import get_flattened_output, annotate_image
from PIL import Image, ImageDraw, ImageFont
import logging
import torch
import json


logger = logging.getLogger(__name__)

class ModelHandler(object):
    """
    A base Model handler implementation.
    """

    def __init__(self):
        self.model = None
        self.model_dir = None
        self.device = 'cpu'
        self.error = None
        # self._context = None
        # self._batch_size = 0
        self.initialized = False
        self._raw_input_data = None
        self._processed_data = None
        self._images_size = None

    def initialize(self, context):
        """
        Initialize model. This will be called during model loading time
        :param context: Initial context contains model server system properties.
        :return:
        """
        logger.info("Loading transformer model")

        self._context = context
        properties = self._context
        # self._batch_size = properties["batch_size"] or 1
        self.model_dir = properties.get("model_dir")
        self.model = self.load(self.model_dir)
        self.initialized = True

    def preprocess(self, batch):
        """
        Transform raw input into model input data.
        :param batch: list of raw requests, should match batch size
        :return: list of preprocessed model input data
        """
        # Take the input data and pre-process it make it inference ready
        # assert self._batch_size == len(batch), "Invalid input batch size: {}".format(len(batch))
        inference_dict = batch
        self._raw_input_data = inference_dict
        model_name_or_path = None
        if 'v2' in self.model.config.architectures[0]:
            model_name_or_path = LAYOUTLMV2_BASE_PATH
        elif 'v3' in self.model.config.architectures[0]:
            model_name_or_path = LAYOUTLMV3_BASE_PATH
        else:
            raise ValueError('invalid model architecture, please make sure the model is either Layoutlmv2 or Layoutlmv3')
        processor = load_processor(model_name_or_path)
        images = [Image.open(path).convert("RGB")
                  for path in inference_dict['image_path']]
        self._images_size = [img.size for img in images]
        words = inference_dict['words']
        boxes = [[normalize_box(box, images[i].size[0], images[i].size[1])
                  for box in doc] for i, doc in enumerate(inference_dict['bboxes'])]
        encoded_inputs = processor(
            images, words, boxes=boxes, return_tensors="pt", padding="max_length", truncation=True)
        self._processed_data = encoded_inputs
        return encoded_inputs

    def load(self, model_dir):
        """The load handler is responsible for loading the hunggingface transformer model.
        Returns:
            hf_pipeline (Pipeline): A Hugging Face Transformer pipeline.
        """
        # TODO model dir should be microsoft/layoutlmv2-base-uncased
        model = load_model(model_dir)
        return model

    def inference(self, model_input):
        """
        Internal inference methods
        :param model_input: transformed model input data
        :return: list of inference output in NDArray
        """
        # TODO load the model state_dict before running the inference
        # Do some inference call to engine here and return output
        with torch.no_grad():
            inference_outputs = self.model(**model_input)
            predictions = inference_outputs.logits.argmax(-1).tolist()
        results = []
        for i in range(len(predictions)):
            tmp = dict()
            tmp[f'output_{i}'] = predictions[i]
            results.append(tmp)

        return [results]

    def postprocess(self, inference_output):
        docs = []
        k = 0
        for page, doc_words in enumerate(self._raw_input_data['words']):
            doc_list = []
            width, height = self._images_size[page]
            for i, doc_word in enumerate(doc_words, start=0):
                word_tagging = None
                word_labels = []
                word = dict()
                word['id'] = k
                k += 1
                word['text'] = doc_word
                word['pageNum'] = page + 1
                word['box'] = self._raw_input_data['bboxes'][page][i]
                _normalized_box = normalize_box(
                    self._raw_input_data['bboxes'][page][i], width, height)
                for j, box in enumerate(self._processed_data['bbox'].tolist()[page]):
                    if compare_boxes(box, _normalized_box):
                        if self.model.config.id2label[inference_output[0][page][f'output_{page}'][j]] != 'O':
                            word_labels.append(
                                self.model.config.id2label[inference_output[0][page][f'output_{page}'][j]][2:])
                        else:
                            word_labels.append('other')
                if word_labels != []:
                    word_tagging = word_labels[0] if word_labels[0] != 'other' else word_labels[-1]
                else:
                    word_tagging = 'other'
                word['label'] = word_tagging
                word['pageSize'] = {'width': width, 'height': height}
                if word['label'] != 'other':
                    doc_list.append(word)
            spans = []
            def adjacents(entity): return [
                adj for adj in doc_list if adjacent(entity, adj)]
            output_test_tmp = doc_list[:]
            for entity in doc_list:
                if adjacents(entity) == []:
                    spans.append([entity])
                    output_test_tmp.remove(entity)

            while output_test_tmp != []:
                span = [output_test_tmp[0]]
                output_test_tmp = output_test_tmp[1:]
                while output_test_tmp != [] and adjacent(span[-1], output_test_tmp[0]):
                    span.append(output_test_tmp[0])
                    output_test_tmp.remove(output_test_tmp[0])
                spans.append(span)

            output_spans = []
            label_to_span_map = {}

            for span in spans:
                label = span[0]['label']
                if label in label_to_span_map:
                    # If the label already exists, merge the current span with the existing span
                    existing_span = label_to_span_map[label]
                    existing_span["text"] += ' ' + ' '.join([entity['text'] for entity in span])
                    existing_span["words"].extend([{
                        'id': entity['id'],
                        'box': entity['box'],
                        'text': entity['text']
                    } for entity in span])
                else:
                    # Create a new span for this label if it doesn't exist
                    output_span = {
                        "text": ' '.join([entity['text'] for entity in span]),
                        "label": label,
                        "words": [{
                            'id': entity['id'],
                            'box': entity['box'],
                            'text': entity['text']
                        } for entity in span]
                    }
                    label_to_span_map[label] = output_span

            # Convert label_to_span_map to output_spans
            output_spans = list(label_to_span_map.values())
            docs.append({f'output': output_spans})
        return [json.dumps(docs, ensure_ascii=False)]

    def handle(self, data, context):
        """
        Call preprocess, inference and post-process functions
        :param data: input data
        :param context: mms context
        """
        model_input = self.preprocess(data)
        model_out = self.inference(model_input)
        inference_out = self.postprocess(model_out)[0]
        import os
        print("cwd", os.getcwd())
        with open('LayoutlMV3InferenceOutput.json', 'w') as inf_out:
            inf_out.write(inference_out)
        inference_out_list = json.loads(inference_out)
        flattened_output_list = get_flattened_output(inference_out_list)
        for i, flattened_output in enumerate(flattened_output_list):
            annotate_image(data['image_path'][i], flattened_output)


_service = ModelHandler()


def handle(data, context):
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    return _service.handle(data, context)
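The main difference from inference_handler.py is that postprocess here merges every span sharing a label into a single entry via label_to_span_map. A toy illustration of that merge (the span data is invented):

# Toy illustration of the per-label merge in the modified postprocess(); data is invented.
spans = [
    [{'id': 0, 'box': [0, 0, 10, 10], 'text': 'GSTIN', 'label': 'GSTIN'}],
    [{'id': 5, 'box': [0, 20, 10, 30], 'text': '22AAAAA0000A1Z5', 'label': 'GSTIN'}],
]

label_to_span_map = {}
for span in spans:
    label = span[0]['label']
    words = [{'id': e['id'], 'box': e['box'], 'text': e['text']} for e in span]
    if label in label_to_span_map:
        # merge into the existing span for this label
        existing = label_to_span_map[label]
        existing['text'] += ' ' + ' '.join(e['text'] for e in span)
        existing['words'].extend(words)
    else:
        label_to_span_map[label] = {'text': ' '.join(e['text'] for e in span),
                                    'label': label, 'words': words}

print(label_to_span_map['GSTIN']['text'])  # GSTIN 22AAAAA0000A1Z5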
layoutlmv3FineTuning/preprocess.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import argparse
|
5 |
+
from datasets.features import ClassLabel
|
6 |
+
from transformers import AutoProcessor
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
+
from datasets import Features, Sequence, ClassLabel, Value, Array2D, Array3D, Dataset
|
9 |
+
from datasets import Image as Img
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
import warnings
|
13 |
+
warnings.filterwarnings('ignore')
|
14 |
+
|
15 |
+
|
16 |
+
def read_text_file(file_path):
|
17 |
+
    with open(file_path, 'r') as f:
        return (f.readlines())


def prepare_examples(examples):
    images = examples[image_column_name]
    words = examples[text_column_name]
    boxes = examples[boxes_column_name]
    word_labels = examples[label_column_name]

    encoding = processor(images, words, boxes=boxes, word_labels=word_labels,
                         truncation=True, padding="max_length")

    return encoding


def get_zip_dir_name():
    try:
        os.chdir('/content/data')
        dir_list = os.listdir()
        any_file_name = dir_list[0]
        zip_dir_name = any_file_name[:any_file_name.find('\\')]
        if all(list(map(lambda x: x.startswith(zip_dir_name), dir_list))):
            return zip_dir_name
        return False
    finally:
        os.chdir('./../')


def filter_out_unannotated(example):
    tags = example['ner_tags']
    return not all([tag == label2id['O'] for tag in tags])


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--valid_size')
    parser.add_argument('--output_path')
    args = parser.parse_args()
    TEST_SIZE = float(args.valid_size)
    OUTPUT_PATH = args.output_path

    os.makedirs(args.output_path, exist_ok=True)
    files = {}
    zip_dir_name = get_zip_dir_name()
    if zip_dir_name:
        files['train_box'] = read_text_file(os.path.join(
            os.curdir, 'data', f'{zip_dir_name}\\{zip_dir_name}_box.txt'))
        files['train_image'] = read_text_file(os.path.join(
            os.curdir, 'data', f'{zip_dir_name}\\{zip_dir_name}_image.txt'))
        files['train'] = read_text_file(os.path.join(
            os.curdir, 'data', f'{zip_dir_name}\\{zip_dir_name}.txt'))
    else:
        for f in os.listdir():
            if f.endswith('.txt') and f.find('box') != -1:
                files['train_box'] = read_text_file(os.path.join(os.curdir, f))
            elif f.endswith('.txt') and f.find('image') != -1:
                files['train_image'] = read_text_file(
                    os.path.join(os.curdir, f))
            elif f.endswith('.txt') and f.find('labels') == -1:
                files['train'] = read_text_file(os.path.join(os.curdir, f))

    assert(len(files['train']) == len(files['train_box']))
    assert(len(files['train_box']) == len(files['train_image']))
    assert(len(files['train_image']) == len(files['train']))

    images = {}
    for i, row in enumerate(files['train_image']):
        if row != '\n':
            image_name = row.split('\t')[-1]
            images.setdefault(image_name.replace('\n', ''), []).append(i)

    words, bboxes, ner_tags, image_path = [], [], [], []
    for image, rows in images.items():
        words.append([row.split('\t')[0].replace('\n', '')
                      for row in files['train'][rows[0]:rows[-1]+1]])
        ner_tags.append([row.split('\t')[1].replace('\n', '')
                         for row in files['train'][rows[0]:rows[-1]+1]])
        bboxes.append([box.split('\t')[1].replace('\n', '')
                       for box in files['train_box'][rows[0]:rows[-1]+1]])
        if zip_dir_name:
            image_path.append(f"/content/data/{zip_dir_name}\\{image}")
        else:
            image_path.append(f"/content/data/{image}")

    labels = list(set([tag for doc_tag in ner_tags for tag in doc_tag]))
    id2label = {v: k for v, k in enumerate(labels)}
    label2id = {k: v for v, k in enumerate(labels)}

    dataset_dict = {
        'id': range(len(words)),
        'tokens': words,
        'bboxes': [[list(map(int, bbox.split())) for bbox in doc] for doc in bboxes],
        'ner_tags': [[label2id[tag] for tag in ner_tag] for ner_tag in ner_tags],
        'image': [Image.open(path).convert("RGB") for path in image_path]
    }

    # raw features
    features = Features({
        'id': Value(dtype='string', id=None),
        'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
        'bboxes': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
        'ner_tags': Sequence(feature=ClassLabel(num_classes=len(labels), names=labels, names_file=None, id=None), length=-1, id=None),
        'image': Img(decode=True, id=None)
    })

    full_data_set = Dataset.from_dict(dataset_dict, features=features)
    dataset = full_data_set.train_test_split(test_size=TEST_SIZE)
    dataset["train"] = dataset["train"].filter(filter_out_unannotated)
    processor = AutoProcessor.from_pretrained(
        "microsoft/layoutlmv3-base", apply_ocr=False)

    features = dataset["train"].features
    column_names = dataset["train"].column_names
    image_column_name = "image"
    text_column_name = "tokens"
    boxes_column_name = "bboxes"
    label_column_name = "ner_tags"

    # we need to define custom features for `set_format` (used later on) to work properly
    features = Features({
        'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
        'input_ids': Sequence(feature=Value(dtype='int64')),
        'attention_mask': Sequence(Value(dtype='int64')),
        'bbox': Array2D(dtype="int64", shape=(512, 4)),
        'labels': Sequence(ClassLabel(names=labels)),
    })

    train_dataset = dataset["train"].map(
        prepare_examples,
        batched=True,
        remove_columns=column_names,
        features=features,
    )
    eval_dataset = dataset["test"].map(
        prepare_examples,
        batched=True,
        remove_columns=column_names,
        features=features,
    )
    train_dataset.set_format("torch")
    if not OUTPUT_PATH.endswith('/'):
        OUTPUT_PATH += '/'
    train_dataset.save_to_disk(f'{OUTPUT_PATH}train_split')
    eval_dataset.save_to_disk(f'{OUTPUT_PATH}eval_split')
    dataset.save_to_disk(f'{OUTPUT_PATH}raw_data')
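Note: a minimal sketch, not part of this commit, of how the splits written by preprocess.py could be reloaded downstream; the "./preprocessed" directory is an assumed stand-in for whatever was passed as --output_path.

from datasets import load_from_disk

# "./preprocessed" is an assumed example path matching the --output_path argument above
train_dataset = load_from_disk("./preprocessed/train_split")
eval_dataset = load_from_disk("./preprocessed/eval_split")
print(train_dataset, eval_dataset)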
layoutlmv3FineTuning/run_inference.py
ADDED
@@ -0,0 +1,31 @@
import argparse
from asyncio.log import logger
from Layoutlm_inference.ocr import prepare_batch_for_inference
from Layoutlm_inference.inference_handler import handle
import logging
import os

if __name__ == "__main__":
    # try:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str)
    parser.add_argument("--images_path", type=str)
    args, _ = parser.parse_known_args()
    images_path = args.images_path
    image_files = os.listdir(images_path)
    images_path = [images_path + f'/{image_file}' for image_file in image_files]
    inference_batch = prepare_batch_for_inference(images_path)
    context = {"model_dir": args.model_path}
    output_ls = handle(inference_batch, context)

    print("output_ls", output_ls)

    # except Exception as err:
    #     os.makedirs('log', exist_ok=True)
    #     logging.basicConfig(filename='log/error_output.log', level=logging.ERROR,
    #                         format='%(asctime)s %(levelname)s %(name)s %(message)s')
    #     logger = logging.getLogger(__name__)
    #     logger.error(err)
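Note: the same entry points can also be driven from Python instead of the CLI; below is a hedged sketch, not part of this commit, that reuses prepare_batch_for_inference and handle exactly as run_inference.py does. The model directory and image folder are assumed example paths.

import os
from Layoutlm_inference.ocr import prepare_batch_for_inference
from Layoutlm_inference.inference_handler import handle

model_dir = "./layoutlmv3_model"   # assumed path to a fine-tuned LayoutLMv3 checkpoint
images_dir = "uploads/aadhar"      # assumed folder of input images

image_paths = [os.path.join(images_dir, f) for f in os.listdir(images_dir)]
inference_batch = prepare_batch_for_inference(image_paths)   # OCR + batching, as in run_inference.py
outputs = handle(inference_batch, {"model_dir": model_dir})  # same call as the script above
print(outputs)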
layoutlmv3FineTuning/run_inferenceM.py
ADDED
@@ -0,0 +1,32 @@
import argparse
from asyncio.log import logger
from Layoutlm_inference.ocr import prepare_batch_for_inference
from Layoutlm_inference.inference_handler import handle
import logging
import os


if __name__ == "__main__":
    # try:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--image_path", type=str, required=True)  # single image path
    args = parser.parse_args()

    # Expecting a single image file
    image_path = args.image_path

    # Ensure the file exists before processing
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"The provided image path does not exist: {image_path}")

    # Prepare batch for a single image
    inference_batch = prepare_batch_for_inference([image_path])  # pass as a list
    context = {"model_dir": args.model_path}

    # Handle the inference
    handle(inference_batch, context)
multiple_request.py
ADDED
@@ -0,0 +1,56 @@
import requests
import concurrent.futures
import time

# Define the API endpoint
# http://43.204.234.114:8000/api/aadhar_ocr
# API_URL = "http://127.0.0.1:8000/api/aadhar_ocr"
API_URL = "http://localhost:8000/api/aadhar_ocr"


# Define the file paths
FILE_PATHS = {
    "aadhar_file": "uploads/aadhar/test_one.jpg",
    # "pan_file": "test_images_pan/6ea33087.jpeg",
    # "cheque_file": "test_images_cheque/0f81678a.jpeg",
    # "gst_file": "test_images_gst/0a52fbcb_page3_image_0.jpg",
}

# Function to send a single POST request
def send_request():
    try:
        start_time = time.time()
        # Open files dynamically for each request
        files = {key: open(path, "rb") for key, path in FILE_PATHS.items()}
        response = requests.post(API_URL, files=files)
        print("this is response\n\n", response)
        end_time = time.time()
        print(f"\nTime taken for one request: {end_time - start_time:.2f} seconds")
        # Close the files after the request
        for file in files.values():
            file.close()
        return response.status_code, response.text
    except requests.exceptions.RequestException as e:
        return "Error", str(e)

# Main function to send multiple concurrent requests
def test_api_concurrency(num_requests):
    start_time = time.time()
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Launch multiple requests concurrently
        results = list(executor.map(lambda _: send_request(), range(num_requests)))
    end_time = time.time()

    # Print results
    for idx, (status, text) in enumerate(results):
        print(f"Request {idx + 1}: Status Code: {status}, Response: {text}")

    print(f"\nTotal time taken: {end_time - start_time:.2f} seconds")

# Number of concurrent requests
NUM_REQUESTS = 8  # Adjust this number based on your testing needs

if __name__ == "__main__":
    test_api_concurrency(NUM_REQUESTS)
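Note: a possible refinement, not part of this commit, is to open the upload files inside contextlib.ExitStack so they are closed even when requests.post raises. A hedged sketch of send_request with that change, assuming the API_URL and FILE_PATHS defined in multiple_request.py above:

import contextlib
import requests

def send_request_safe():
    # Same POST as send_request(), but every file handle is closed on all exit paths
    try:
        with contextlib.ExitStack() as stack:
            files = {key: stack.enter_context(open(path, "rb"))
                     for key, path in FILE_PATHS.items()}
            response = requests.post(API_URL, files=files)
        return response.status_code, response.text
    except requests.exceptions.RequestException as e:
        return "Error", str(e)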
requirements.txt
ADDED
@@ -0,0 +1,9 @@
fastapi
uvicorn[standard]
python-multipart
git+https://github.com/huggingface/transformers.git
git+https://github.com/huggingface/datasets.git
transformers[torch]
pillow
google-cloud-vision
python-dotenv
sample.py
ADDED
@@ -0,0 +1,22 @@
import requests

# Define the API endpoint
# url = "http://127.0.0.0:7860/api/home"

post_url = "http://localhost:7860/api/aadhar_ocr"
# response = requests.get(url)

# print()
# Define the file paths
files = {
    "aadhar_file": open("/home/javmulla/model_one/test_images_aadhar/test_two.jpg", "rb"),
    "pan_file": open("/home/javmulla/model_one/test_images_pan/6ea33087.jpeg", "rb"),
    "cheque_file": open("/home/javmulla/model_one/test_images_cheque/0f81678a.jpeg", "rb"),
    "gst_file": open("/home/javmulla/model_one/test_images_gst/0a52fbcb_page3_image_0.jpg", "rb"),
}

response = requests.post(post_url, files=files)

# Print the response
print("Status Code:", response.status_code)
print("Response Text:", response.text)
titanium-scope-436311-t3-966373f5aa2f.json
ADDED
@@ -0,0 +1,13 @@
{
"type": "service_account",
"project_id": "titanium-scope-436311-t3",
"private_key_id": "966373f5aa2f27bb48fee7cd9d4afd6b1b432387",
"private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCUMbbv+N8zoYiY\nBECTcq6vZR/biV5sYlToXujzDw5iYHtMAPX5V6Z2ORgOuvq4WTAmozwKG2LLrapr\ntFKqeKBAhQR79Jlrek6efYwIgI/PVtDNhvylBg7ZfINX5HVE9tHxSS66jKEQwBmq\nW4tILP3BvkxDv0FhjkKtO9D1tm3omQIhg7B0+0T9IjGlQ/8Y67NMHDuP4MWTP2r6\njU9ulYp3r10ZSc+jZHX3jXA0UCM5LehYorZb3/GEldMdKvZ4RJvMDaolbu6aE9zY\nTkrVuzo6uNJgv1h+FYTjvnbjT2AYq0H4KcLXgQZVf0F72ibmisjMqA0XehIrTbhY\nbEGe26oZAgMBAAECggEAPhdSYdtxkY110NO/Rsg/PsftACvfPyQ4FSBnFCfTzA5G\nusKQTQeXfGNRnCJlmEXuMdIk/ssYquQ5ymTEWh6ubjoNde43NdwKAsfxm0JafvIO\nDH8pbe9K238a/QGAzQNpVWJnTMxNU9pZJpKymewX6kxUYfJJb5mOgEzWsYzdIh4O\nl1XuylR2m0OK+NgAqhuFvqFkRqem6tlfDhGl+dIQNZ60OVXew0xEMV6x/z1OYTqR\nS+S0GUcfZB6OVIv/anKZ8s49noBuR/JkMX2sIaXCTcicL0o44n2ROUw0jcxKxi+6\nv8IQNcxm28b9SbPNgxb4KCdCOqF9iePcLLLr4S/QHQKBgQDEwi51FynUzCm4sEVE\ndEa32xkrkYe/gZxg4kTH5Sn4Ts1Xidg2z1HreaxTM3Nomu+2PUWrFu7n2YwAkvdx\nrWGoegRZNNTlga8yFid24BqjeszS/hO5Fg7PGN+beDcp63NVirYTkQjZ8FM6UA50\nKZ8c6Qyt/bGaihDUWsXdbuTobwKBgQDA0EatcxAW3sl7f6Mw82wHiC2mmMF7g+f0\ntC0B7xirrf9TSSXDSwYxWUJ6rAxTjoskjmi+lBw+XAIDz5bWPkuwya0zeOdCh9yp\ndEvv3pm8puPzwFNLh7OWyROW3cmV5C1tLGqdGyYr7WkHGXAZCkc9U5wFQY683j5o\n3b7skCSJ9wKBgCH7g7iXapMlO+N5Fk2PY5NnlP5QYUizIwYcrlJ0Av6u5YpD9YLp\n5bUsy5WHIlyjvdkU1g6JpHOIwERtHa2Vi3Nkt5GMrWSCNHcLGn/OjutDT1L1rQRf\niek824nnhmeIEeBpV68jcorplgZRQ13OvntoyNbYJS+SvvtePiRTfdejAoGATqQk\nT5ZAl7NiZjaW7t45z5ChXfOr5p7UOqBKQyGr5Enhe6y39EFjUzlevf3yQRpAcjaL\nTj/GjUClqbw/fz6FTKPVOsszN5WGUK8YUctu1N0U2FQ3JPVCMFvu23e2QqaASKj3\nCwEJvpzkW3rql6vzhnXViudEOpBC0C6xMndQD90CgYAQYkIA1O5aGjT8TMhxa1l+\ng0/mxmeVowUUzG1Yntr3LQKQE/+tRZ/gE9N57PutecYIrWmOGPhEF89LF4nI1uqF\nWEdatBSr80jcHL8LiNPdVHwa8G3AH7MoR5Tq7RzYyDKxJKtf9IZJZ9dMve0iVb1R\n5J6yvGjzRzs7xnG3lvUpYQ==\n-----END PRIVATE KEY-----\n",
"client_email": "vision-service@titanium-scope-436311-t3.iam.gserviceaccount.com",
"client_id": "105182341558314183890",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/vision-service%40titanium-scope-436311-t3.iam.gserviceaccount.com",
"universe_domain": "googleapis.com"
}
uploads/aadhar/test_one.jpg
ADDED
uploads/aadhar/test_two.jpg
ADDED
uploads/cheque/0f81678a.jpeg
ADDED
uploads/gst/0a52fbcb_page3_image_0.jpg
ADDED
uploads/pan/6ea33087.jpeg
ADDED