"""Drawing-layout pipeline: RT-DETR detection + PaddleOCR text extraction.

Detects note / part-drawing / table regions, crops them, OCRs notes and
tables, and writes a visualization image plus a JSON summary.
"""
import torch
from paddleocr import PaddleOCR

# ── Detector loading ─────────────────────────────────────
_model = None  # process-wide cache so the checkpoint is loaded once


def get_model(checkpoint: str = "best.pt"):
    """Return the cached RT-DETR detector, loading *checkpoint* on first use."""
    global _model
    if _model is not None:
        return _model
    print(f"[INFO] Loading model from {checkpoint}...")
    _model = RTDETR(checkpoint)
    return _model
# Monkeypatch torch.load so weights_only defaults to False — presumably to
# cope with newer torch releases where torch.load defaults to
# weights_only=True and rejects full pickled checkpoints (TODO confirm the
# exact torch version this targets).
# SECURITY NOTE(review): weights_only=False executes arbitrary pickle code
# during load — only use checkpoints from trusted sources.
# NOTE: the patch is installed BEFORE importing ultralytics below, so any
# torch.load call made from ultralytics picks up the relaxed default.
_orig_load = torch.load
def _safe_load(*args, **kwargs):
    # Only the default changes; an explicit weights_only from the caller wins.
    kwargs.setdefault("weights_only", False)
    return _orig_load(*args, **kwargs)
torch.load = _safe_load
# ─────────────────────────────────────────────────────────
import cv2, json, os
from pathlib import Path
from ultralytics import RTDETR
# ── Device selection: prefer Apple-Silicon MPS, else CPU ──
if torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"[INFO] Device: {DEVICE}")
# ── Class configuration ──────────────────────────────────
# Raw class names, in the index order the detector emits.
CLASS_NAMES = ['note', 'part-drawing', 'table']

# Canonical display names required by the task spec.
CLASS_DISPLAY = dict([
    ('note', 'Note'),
    ('part-drawing', 'PartDrawing'),
    ('table', 'Table'),
])

# BGR colors used when drawing boxes, keyed by raw class name.
COLORS = dict(zip(CLASS_NAMES, [
    (0, 165, 255),   # orange
    (0, 200, 0),     # green
    (220, 0, 0),     # red
]))
# ================== OCR — runs on Mac M1 with PP-OCRv5 ==================
from paddleocr import PaddleOCR, PPStructureV3  # PPStructureV3 replaces the old PPStructure
import cv2

# Lazily-created engine singletons (model load is expensive).
_ocr_engine = None
_table_engine = None


def get_ocr():
    """Return the cached plain-text OCR engine, used for Note regions."""
    global _ocr_engine
    if _ocr_engine is not None:
        return _ocr_engine
    _ocr_engine = PaddleOCR(
        lang="vi",
        use_textline_orientation=True,  # replaces the deprecated use_angle_cls
    )
    return _ocr_engine
def get_table_engine():
    """Return the cached PP-StructureV3 engine (keeps table rows/columns)."""
    global _table_engine
    if _table_engine is not None:
        return _table_engine
    _table_engine = PPStructureV3()
    return _table_engine
def ocr_note(img_path):
    """OCR a cropped Note image; return recognized lines joined by newlines.

    The engine is configured with the paddleocr 3.x parameter
    ``use_textline_orientation``, but the original parsing assumed the 2.x
    result shape ``[[box, (text, score)], ...]`` — on 3.x, pages are
    dict-like objects exposing ``rec_texts``, so that indexing breaks.
    This version accepts both shapes.

    Returns "" when nothing is recognized.
    """
    ocr = get_ocr()
    result = ocr.ocr(img_path)
    if not result or not result[0]:
        return ""
    page = result[0]
    try:
        # paddleocr >= 3.x: dict-like result with the recognized strings.
        texts = page["rec_texts"]
    except (TypeError, KeyError):
        # paddleocr 2.x: list of [box, (text, score)] entries.
        texts = [line[1][0] for line in page]
    return "\n".join(texts)
def ocr_table(img_path):
    """OCR a cropped Table image, preferring structure-aware recognition.

    Falls back to plain OCR (``ocr_note``) if the structure engine fails
    for any reason.
    """
    try:
        engine = get_table_engine()
        image = cv2.imread(img_path)
        return str(engine(image))  # downstream consumers accept this form
    except Exception as e:
        print(f"[WARN] Table structure failed: {e}, fallback to plain OCR")
        return ocr_note(img_path)
# ── Main pipeline ─────────────────────────────────────────
def _crop_with_padding(img_bgr, x1, y1, x2, y2, pad=4):
    """Crop a bbox from img_bgr with *pad* pixels of margin, clamped to bounds."""
    h, w = img_bgr.shape[:2]
    cx1, cy1 = max(0, x1 - pad), max(0, y1 - pad)
    cx2, cy2 = min(w, x2 + pad), min(h, y2 + pad)
    return img_bgr[cy1:cy2, cx1:cx2]


def _draw_detection(img_bgr, x1, y1, x2, y2, label, color):
    """Draw one bbox and a filled label tag onto img_bgr (in place)."""
    cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, 2)
    (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
    # Clamp the tag downward when the box touches the top edge; otherwise it
    # would be drawn at negative y and silently clipped (invisible label).
    ty = max(y1, th + 8)
    cv2.rectangle(img_bgr, (x1, ty - th - 8), (x1 + tw + 4, ty), color, -1)
    cv2.putText(img_bgr, label, (x1 + 2, ty - 4),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)


def run_pipeline(
    image_path: str,
    output_dir: str = "outputs",
    checkpoint: str = "best.pt",
    conf: float = 0.3,
) -> tuple[dict, str]:
    """
    Run the full pipeline: detect -> crop -> OCR -> JSON.

    Args:
        image_path: Input drawing image.
        output_dir: Root directory for crops, visualization and JSON.
        checkpoint: Detector checkpoint path.
        conf: Detection confidence threshold.

    Returns:
        (result_dict, visualized_image_path)

    Raises:
        ValueError: if the image cannot be read.
    """
    image_path = str(image_path)
    src = Path(image_path)
    img_name, stem = src.name, src.stem
    out_root = Path(output_dir) / stem
    crop_dir = out_root / "crops"
    crop_dir.mkdir(parents=True, exist_ok=True)

    # 1. Detect
    model = get_model(checkpoint)
    results = model(
        image_path,
        imgsz=1024,
        conf=conf,
        iou=0.5,
        device=DEVICE,
        verbose=False,
    )

    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        raise ValueError(f"Không đọc được ảnh: {image_path}")

    objects = []
    for i, box in enumerate(results[0].boxes):
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        cls_idx = int(box.cls[0])
        conf_val = round(float(box.conf[0]), 4)
        cls_raw = CLASS_NAMES[cls_idx]
        cls_show = CLASS_DISPLAY[cls_raw]

        # 2. Crop (small padding around the bbox helps OCR)
        crop = _crop_with_padding(img_bgr, x1, y1, x2, y2)
        crop_path = str(crop_dir / f"{cls_show}_{i+1}.jpg")
        cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 95])

        # 3. OCR — Note gets plain OCR, Table gets structure-aware OCR,
        #    part-drawing gets none.
        ocr_content = None
        if cls_raw == 'note':
            ocr_content = ocr_note(crop_path)
        elif cls_raw == 'table':
            ocr_content = ocr_table(crop_path)

        objects.append({
            "id": i + 1,
            "class": cls_show,
            "confidence": conf_val,
            "bbox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
            "crop_path": crop_path,
            "ocr_content": ocr_content,
        })

        # 4. Draw the bbox on the visualization image (after cropping, so the
        #    crop itself is free of drawing artifacts)
        _draw_detection(img_bgr, x1, y1, x2, y2,
                        f"{cls_show} {conf_val:.2f}", COLORS[cls_raw])

    # 5. Save visualization
    vis_path = str(out_root / "result_vis.jpg")
    cv2.imwrite(vis_path, img_bgr)

    # 6. Save JSON
    result = {"image": img_name, "objects": objects}
    json_path = str(out_root / "result.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print(f"[✓] {img_name}: {len(objects)} objects → {json_path}")
    return result, vis_path
# ── Quick CLI smoke test ──────────────────────────────────
if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    img = cli_args[0] if cli_args else "test.jpg"
    result, vis = run_pipeline(img)
    print(json.dumps(result, ensure_ascii=False, indent=2))