# Harry Pham
# init project
# ea9cf0f
import torch
from paddleocr import PaddleOCR
# ── Load model ───────────────────────────────────────────
_model = None
_model_checkpoint = None  # checkpoint path the cached _model was loaded from


def get_model(checkpoint: str = "best.pt"):
    """Return the cached RT-DETR detector, loading it when needed.

    The model is stored in the module-level ``_model`` so repeated calls do
    not reload the weights. The cache also remembers which checkpoint it was
    built from: asking for a different checkpoint reloads the model instead
    of silently returning one built from another file.

    Args:
        checkpoint: Path of the weights file handed to ``RTDETR``.

    Returns:
        The ``RTDETR`` instance for ``checkpoint``.
    """
    global _model, _model_checkpoint
    # Normalise so that a str and an equivalent Path-like compare equal.
    requested = str(checkpoint)
    if _model is None or _model_checkpoint != requested:
        print(f"[INFO] Loading model from {checkpoint}...")
        _model = RTDETR(checkpoint)
        _model_checkpoint = requested
    return _model
# Keep a handle on the real torch.load before it is replaced below.
_orig_load = torch.load


def _safe_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only=False`` by default.

    An explicit ``weights_only`` keyword from the caller is left untouched;
    every other argument is forwarded unchanged.

    NOTE(review): once installed, this changes the default for every
    ``torch.load`` call in the process. With ``weights_only=False`` the file
    is fully unpickled, so only load checkpoints from trusted sources.
    """
    if "weights_only" not in kwargs:
        kwargs["weights_only"] = False
    return _orig_load(*args, **kwargs)


# Install the wrapper globally so libraries calling torch.load pick it up.
torch.load = _safe_load
# ─────────────────────────────────────────────────────────
import cv2, json, os
from pathlib import Path
from ultralytics import RTDETR
# ── Device: use MPS on Apple M1, otherwise CPU ───────────
# Picked once at import time and passed to the detector as its device.
if torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"[INFO] Device: {DEVICE}")
# ── Class config ─────────────────────────────────────────
# Raw class names; the pipeline looks them up by the detector's class index.
CLASS_NAMES = ['note', 'part-drawing', 'table']

# Map raw class names to the standard display names required by the task.
CLASS_DISPLAY = {
    'note': 'Note',
    'part-drawing': 'PartDrawing',
    'table': 'Table',
}

# Box colour per raw class. These tuples are handed to OpenCV drawing calls
# on a BGR image, so the channel order is (blue, green, red).
COLORS = {
    'note': (0, 165, 255),        # orange
    'part-drawing': (0, 200, 0),  # green
    'table': (0, 0, 220),         # red (was (220, 0, 0), which is blue in BGR)
}
# ================== OCR MỚI - HOẠT ĐỘNG TRÊN MAC M1 + PP-OCRv5 ==================
from paddleocr import PaddleOCR, PPStructureV3 # ← SỬA Ở ĐÂY: PPStructure β†’ PPStructureV3
import cv2
# Lazily created engines, shared by every call in this process.
_ocr_engine = None
_table_engine = None


def get_ocr():
    """Return the shared plain-text OCR engine used for Note regions.

    The ``PaddleOCR`` instance is created on the first call and stored in the
    module-level ``_ocr_engine``; later calls return that same instance.
    """
    global _ocr_engine
    if _ocr_engine is not None:
        return _ocr_engine
    # use_textline_orientation replaces the older use_angle_cls option.
    _ocr_engine = PaddleOCR(use_textline_orientation=True, lang="vi")
    return _ocr_engine
def get_table_engine():
    """Return the shared table-structure engine (keeps rows/columns).

    A ``PPStructureV3`` instance is built with default settings on the first
    call and stored in the module-level ``_table_engine`` for reuse.
    """
    global _table_engine
    if _table_engine is not None:
        return _table_engine
    _table_engine = PPStructureV3()
    return _table_engine
def ocr_note(img_path):
    """Run plain-text OCR on an image and return the recognised lines.

    Two result layouts are handled, because the layout depends on the
    installed PaddleOCR generation:

    * PaddleOCR 3.x: each page is a dict-like result whose ``"rec_texts"``
      entry lists the recognised strings.
    * PaddleOCR 2.x: each page is a list of ``[box, (text, score)]`` items.

    Only the first page of the result is used.

    Args:
        img_path: Path of the image to recognise.

    Returns:
        The recognised lines joined by newlines, or ``""`` when nothing was
        recognised.
    """
    # cls=True is intentionally no longer passed to ocr().
    result = get_ocr().ocr(img_path)
    if not result or not result[0]:
        return ""

    page = result[0]
    if hasattr(page, "get"):
        # Dict-like page (3.x). Iterating it would yield key names, not text
        # lines, so read the recognised strings from "rec_texts" instead.
        texts = page.get("rec_texts") or []
    else:
        # Legacy list layout: line == [box, (text, score)].
        texts = [line[1][0] for line in page]
    return "\n".join(str(text) for text in texts)
def ocr_table(img_path):
    """Recognise a table image, preferring structure-aware output.

    The image is run through the table-structure engine and the result is
    returned as its string form. If anything in that path fails (unreadable
    image, engine error), a warning is printed and the function falls back
    to plain-text OCR via ``ocr_note``.

    Args:
        img_path: Path of the table image.

    Returns:
        ``str(result)`` of the structure engine, or the plain OCR text when
        the structure path failed.
    """
    try:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise ValueError(f"cannot read image: {img_path}")
        engine = get_table_engine()
        # PaddleOCR 3.x pipelines expose predict(); older structure engines
        # were invoked by calling the object directly. Support both.
        run = getattr(engine, "predict", None) or engine
        result = run(img)
        return str(result)  # this form is usually accepted as the expected output
    except Exception as e:
        # Deliberate best-effort: any failure degrades to plain OCR.
        print(f"[WARN] Table structure failed: {e}, fallback to plain OCR")
        return ocr_note(img_path)
# ── Main pipeline ─────────────────────────────────────────
def _draw_detection(canvas, x1, y1, x2, y2, label, color):
    """Draw one box outline and its filled caption tag on *canvas* in place."""
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness)
    (tw, th), _ = cv2.getTextSize(label, font, scale, thickness)
    tag_h = th + 8
    # The tag normally sits above the box. When the box is too close to the
    # top edge for that, hang the tag just inside the box so it stays visible.
    tag_bottom = y1 if y1 >= tag_h else y1 + tag_h
    cv2.rectangle(canvas,
                  (x1, tag_bottom - tag_h), (x1 + tw + 4, tag_bottom),
                  color, -1)
    cv2.putText(canvas, label,
                (x1 + 2, tag_bottom - 4),
                font, scale,
                (255, 255, 255), thickness)


def run_pipeline(
    image_path: str,
    output_dir: str = "outputs",
    checkpoint: str = "best.pt",
    conf: float = 0.3,
) -> tuple[dict, str]:
    """Run the full pipeline: detect -> crop -> OCR -> JSON.

    For every detected box the region is cropped (with a small padding) and
    saved under ``<output_dir>/<image stem>/crops``; 'note' and 'table' crops
    are additionally OCR'd. An annotated copy of the image and a JSON summary
    are written to ``<output_dir>/<image stem>``.

    Args:
        image_path: Path of the input image.
        output_dir: Root directory for all outputs.
        checkpoint: Detector weights file, forwarded to ``get_model``.
        conf: Confidence threshold passed to the detector.

    Returns:
        ``(result_dict, visualized_image_path)`` where ``result_dict`` is
        ``{"image": <file name>, "objects": [...]}``.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    image_path = str(image_path)
    source = Path(image_path)
    img_name = source.name
    stem = source.stem
    out_root = Path(output_dir) / stem
    crop_dir = out_root / "crops"
    crop_dir.mkdir(parents=True, exist_ok=True)

    # 1. Detect
    model = get_model(checkpoint)
    results = model(
        image_path,
        imgsz=1024,
        conf=conf,
        iou=0.5,
        device=DEVICE,
        verbose=False,
    )
    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        raise ValueError(f"Không đọc được ảnh: {image_path}")

    # Annotations are drawn on a separate copy so that crops are always taken
    # from clean pixels. Drawing on the image that is also being cropped
    # would leak earlier boxes/labels into later crops and into their OCR.
    canvas = img_bgr.copy()
    img_h, img_w = img_bgr.shape[:2]
    pad = 4  # small padding around each bbox

    objects = []
    for idx, box in enumerate(results[0].boxes, start=1):
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        cls_idx = int(box.cls[0])
        conf_val = round(float(box.conf[0]), 4)
        cls_raw = CLASS_NAMES[cls_idx]
        cls_show = CLASS_DISPLAY[cls_raw]

        # 2. Crop (padded, clamped to the image bounds)
        cx1 = max(0, x1 - pad)
        cy1 = max(0, y1 - pad)
        cx2 = min(img_w, x2 + pad)
        cy2 = min(img_h, y2 + pad)
        crop = img_bgr[cy1:cy2, cx1:cx2]
        crop_path = str(crop_dir / f"{cls_show}_{idx}.jpg")
        cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 95])

        # 3. OCR (only notes and tables carry text of interest)
        ocr_content = None
        if cls_raw == 'note':
            ocr_content = ocr_note(crop_path)
        elif cls_raw == 'table':
            ocr_content = ocr_table(crop_path)

        objects.append({
            "id": idx,
            "class": cls_show,
            "confidence": conf_val,
            "bbox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
            "crop_path": crop_path,
            "ocr_content": ocr_content,
        })

        # 4. Draw the bbox and its label on the visualisation copy
        _draw_detection(canvas, x1, y1, x2, y2,
                        f"{cls_show} {conf_val:.2f}", COLORS[cls_raw])

    # 5. Save the visualisation image
    vis_path = str(out_root / "result_vis.jpg")
    cv2.imwrite(vis_path, canvas)

    # 6. Save the JSON summary
    result = {"image": img_name, "objects": objects}
    json_path = str(out_root / "result.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print(f"[✓] {img_name}: {len(objects)} objects → {json_path}")
    return result, vis_path
# ── CLI test nhanh ────────────────────────────────────────
if __name__ == "__main__":
    import sys

    # The first CLI argument is the image path; without one, use "test.jpg".
    cli_args = sys.argv[1:]
    img = cli_args[0] if cli_args else "test.jpg"
    result, vis = run_pipeline(img)
    # Echo the JSON summary to stdout for a quick visual check.
    print(json.dumps(result, ensure_ascii=False, indent=2))