longjava2024 commited on
Commit
d78e827
·
verified ·
1 Parent(s): f3d290d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +289 -0
app.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import json
4
+ import ast
5
+ import re
6
+ from io import BytesIO
7
+ import types
8
+ import sys
9
+
10
+ # Force CPU-only & disable bitsandbytes CUDA checks in this environment
11
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
12
+ os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
13
+ os.environ.setdefault("BITSANDBYTES_DISABLE_CUDA_CHECK", "1")
14
+
15
+ import torch
16
+ import torchvision.transforms as T
17
+ from PIL import Image
18
+ from torchvision.transforms.functional import InterpolationMode
19
+ import gradio as gr
20
+
21
+ # Stub bitsandbytes and flash_attn to avoid GPU driver checks in CPU-only environments
22
+ fake_bnb = types.ModuleType("bitsandbytes")
23
+ def _bnb_unavailable(*args, **kwargs):
24
+ raise ImportError("bitsandbytes is not available in this CPU-only deployment")
25
+ fake_bnb.__all__ = ["_bnb_unavailable"]
26
+ fake_bnb._bnb_unavailable = _bnb_unavailable
27
+ sys.modules["bitsandbytes"] = fake_bnb
28
+
29
+ fake_flash = types.ModuleType("flash_attn")
30
+ sys.modules["flash_attn"] = fake_flash
31
+
32
+ from transformers import AutoModel, AutoTokenizer
33
+
34
+
35
+ MODEL_NAME = "5CD-AI/Vintern-1B-v2"
36
+ DEVICE = "cpu"
37
+ DTYPE = torch.float32
38
+
39
+ print(f"Loading model `{MODEL_NAME}` on {DEVICE} ...")
40
+ tokenizer = AutoTokenizer.from_pretrained(
41
+ MODEL_NAME,
42
+ trust_remote_code=True,
43
+ use_fast=False,
44
+ )
45
+ model = AutoModel.from_pretrained(
46
+ MODEL_NAME,
47
+ torch_dtype=DTYPE,
48
+ low_cpu_mem_usage=True,
49
+ trust_remote_code=True,
50
+ )
51
+ model.eval().to(DEVICE)
52
+
53
+ generation_config = dict(
54
+ max_new_tokens=512,
55
+ do_sample=False,
56
+ num_beams=3,
57
+ repetition_penalty=3.5,
58
+ )
59
+
60
+
61
+ # =========================
62
+ # Image preprocessing (from notebook)
63
+ # =========================
64
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
65
+ IMAGENET_STD = (0.229, 0.224, 0.225)
66
+
67
+
68
+ def build_transform(input_size: int):
69
+ mean, std = IMAGENET_MEAN, IMAGENET_STD
70
+ transform = T.Compose(
71
+ [
72
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
73
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
74
+ T.ToTensor(),
75
+ T.Normalize(mean=mean, std=std),
76
+ ]
77
+ )
78
+ return transform
79
+
80
+
81
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
82
+ best_ratio_diff = float("inf")
83
+ best_ratio = (1, 1)
84
+ area = width * height
85
+ for ratio in target_ratios:
86
+ target_aspect_ratio = ratio[0] / ratio[1]
87
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
88
+ if ratio_diff < best_ratio_diff:
89
+ best_ratio_diff = ratio_diff
90
+ best_ratio = ratio
91
+ elif ratio_diff == best_ratio_diff:
92
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
93
+ best_ratio = ratio
94
+ return best_ratio
95
+
96
+
97
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
98
+ orig_width, orig_height = image.size
99
+ aspect_ratio = orig_width / orig_height
100
+
101
+ target_ratios = set(
102
+ (i, j)
103
+ for n in range(min_num, max_num + 1)
104
+ for i in range(1, n + 1)
105
+ for j in range(1, n + 1)
106
+ if i * j <= max_num and i * j >= min_num
107
+ )
108
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
109
+
110
+ target_aspect_ratio = find_closest_aspect_ratio(
111
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size
112
+ )
113
+
114
+ target_width = image_size * target_aspect_ratio[0]
115
+ target_height = image_size * target_aspect_ratio[1]
116
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
117
+
118
+ resized_img = image.resize((target_width, target_height))
119
+ processed_images = []
120
+ for i in range(blocks):
121
+ box = (
122
+ (i % (target_width // image_size)) * image_size,
123
+ (i // (target_width // image_size)) * image_size,
124
+ ((i % (target_width // image_size)) + 1) * image_size,
125
+ ((i // (target_width // image_size)) + 1) * image_size,
126
+ )
127
+ split_img = resized_img.crop(box)
128
+ processed_images.append(split_img)
129
+ assert len(processed_images) == blocks
130
+ if use_thumbnail and len(processed_images) != 1:
131
+ thumbnail_img = image.resize((image_size, image_size))
132
+ processed_images.append(thumbnail_img)
133
+ return processed_images
134
+
135
+
136
+ def load_image_from_base64(base64_string: str, input_size=448, max_num=12):
137
+ if base64_string.startswith("data:image"):
138
+ base64_string = base64_string.split(",", 1)[1]
139
+
140
+ image_data = base64.b64decode(base64_string)
141
+ image = Image.open(BytesIO(image_data)).convert("RGB")
142
+ transform = build_transform(input_size=input_size)
143
+ images = dynamic_preprocess(
144
+ image, image_size=input_size, use_thumbnail=True, max_num=max_num
145
+ )
146
+ pixel_values = [transform(img) for img in images]
147
+ pixel_values = torch.stack(pixel_values)
148
+ return pixel_values
149
+
150
+
151
+ # =========================
152
+ # Prompt & helpers
153
+ # =========================
154
+ PROMPT = """<image>
155
+ Bạn là hệ thống OCR + trích xuất dữ liệu từ ảnh Căn cước công dân (CCCD) Việt Nam.
156
+ Nhiệm vụ: đọc đúng chữ trên thẻ và trả về CHỈ 1 đối tượng JSON theo schema quy định.
157
+
158
+ QUY TẮC BẮT BUỘC:
159
+ 1) Chỉ trả về JSON thuần (không markdown, không giải thích, không thêm ký tự nào ngoài JSON).
160
+ 2) Chỉ được có đúng 5 khóa sau (đúng chính tả, đúng chữ thường, có dấu gạch dưới):
161
+ - "so_no"
162
+ - "ho_va_ten"
163
+ - "ngay_sinh"
164
+ - "que_quan"
165
+ - "noi_thuong_tru"
166
+ Không được thêm bất kỳ khóa nào khác.
167
+ 3) Mapping trường (lấy theo NHÃN in trên thẻ, không lấy từ QR):
168
+ - so_no: lấy giá trị ngay sau nhãn "Số / No." (hoặc "Số/No.").
169
+ - ho_va_ten: lấy giá trị ngay sau nhãn "Họ và tên / Full name".
170
+ - ngay_sinh: lấy giá trị ngay sau nhãn "Ngày sinh / Date of birth"; nếu có định dạng dd/mm/yyyy thì giữ đúng dd/mm/yyyy.
171
+ - que_quan: lấy giá trị ngay sau nhãn "Quê quán / Place of origin".
172
+ - noi_thuong_tru: lấy giá trị ngay sau nhãn "Nơi thường trú / Place of residence".
173
+ 4) Nếu trường nào không đọc được rõ/chắc chắn: đặt null. Không được suy đoán.
174
+ 5) Chuẩn hoá: trim khoảng trắng đầu/cuối; giữ nguyên dấu tiếng Việt và chữ hoa/thường như trong ảnh.
175
+
176
+ CHỈ TRẢ VỀ THEO MẪU JSON NÀY:
177
+ {
178
+ "so_no": "... hoặc null",
179
+ "ho_va_ten": "... hoặc null",
180
+ "ngay_sinh": "... hoặc null",
181
+ "que_quan": "... hoặc null",
182
+ "noi_thuong_tru": "... hoặc null"
183
+ }
184
+ """
185
+
186
+
187
+ def parse_response_to_json(response_text: str):
188
+ if not response_text:
189
+ return None
190
+
191
+ s = response_text.strip()
192
+
193
+ if s.startswith('"') and s.endswith('"'):
194
+ s = s[1:-1].replace('\\"', '"')
195
+
196
+ try:
197
+ obj = json.loads(s)
198
+ if isinstance(obj, dict):
199
+ return obj
200
+ except json.JSONDecodeError:
201
+ pass
202
+
203
+ try:
204
+ obj = ast.literal_eval(s)
205
+ if isinstance(obj, dict):
206
+ return obj
207
+ except (ValueError, SyntaxError):
208
+ pass
209
+
210
+ json_pattern = r"\{[\s\S]*\}"
211
+ m = re.search(json_pattern, s)
212
+ if m:
213
+ chunk = m.group(0).strip()
214
+ try:
215
+ obj = ast.literal_eval(chunk)
216
+ if isinstance(obj, dict):
217
+ return obj
218
+ except Exception:
219
+ pass
220
+ try:
221
+ chunk2 = chunk.replace("'", '"')
222
+ obj = json.loads(chunk2)
223
+ if isinstance(obj, dict):
224
+ return obj
225
+ except Exception:
226
+ pass
227
+
228
+ return {"text": response_text}
229
+
230
+
231
+ def normalize_base64(image_base64: str) -> str:
232
+ if not image_base64:
233
+ return image_base64
234
+ image_base64 = image_base64.strip()
235
+ if image_base64.startswith("data:"):
236
+ parts = image_base64.split(",", 1)
237
+ if len(parts) == 2:
238
+ return parts[1]
239
+ return image_base64
240
+
241
+
242
+ def ocr_by_llm(image_base64: str, prompt: str) -> str:
243
+ pixel_values = load_image_from_base64(image_base64, max_num=6)
244
+ pixel_values = pixel_values.to(dtype=torch.float32, device=DEVICE)
245
+ with torch.no_grad():
246
+ response_message = model.chat(
247
+ tokenizer,
248
+ pixel_values,
249
+ prompt,
250
+ generation_config,
251
+ )
252
+ del pixel_values
253
+ return response_message
254
+
255
+
256
+ def predict(image_base64: str):
257
+ """
258
+ Hàm chính cho API: nhận base64 ảnh CCCD, trả về JSON các trường.
259
+ Dùng được cả qua UI Gradio và HF Inference API: /run/predict với {"data": ["<base64>"]}
260
+ """
261
+ image_base64 = normalize_base64(image_base64)
262
+ if not image_base64:
263
+ return {"error": "image_base64 is required"}
264
+ try:
265
+ response_message = ocr_by_llm(image_base64, PROMPT)
266
+ parsed = parse_response_to_json(response_message)
267
+ return parsed
268
+ except Exception as e:
269
+ return {"error": str(e)}
270
+
271
+
272
+ demo = gr.Interface(
273
+ fn=predict,
274
+ inputs=gr.Textbox(
275
+ lines=4,
276
+ label="image_base64",
277
+ placeholder="Dán chuỗi base64 của ảnh CCCD (có thể ở dạng )",
278
+ ),
279
+ outputs=gr.JSON(label="Kết quả OCR JSON"),
280
+ title="CCCD OCR API (Vintern-1B-v2)",
281
+ description=(
282
+ "API dùng Vintern-1B-v2 để đọc ảnh CCCD và trả về JSON 5 trường. "
283
+ "Gọi qua Inference API: POST /run/predict với body {\"data\": [\"<image_base64>\"]}."
284
+ ),
285
+ )
286
+
287
+
288
+ if __name__ == "__main__":
289
+ demo.launch()