# andrew33333's picture
# Upload folder using huggingface_hub
# c147abc verified
"""Simple DocLayout model for inference."""
import json
from pathlib import Path
from typing import Dict, List, Union
import numpy as np
from PIL import Image
from ultralytics import YOLO
class DocLayoutModel:
    """
    Document layout detection model.

    Wraps an Ultralytics ``YOLO`` detector and converts its raw results into
    plain dictionaries with a human-readable class name attached.

    Examples
    --------
    >>> model = DocLayoutModel("model.pt")
    >>> results = model.predict("document.png")
    >>> for det in results:
    ...     print(f"{det['class_name']}: {det['confidence']:.2f}")
    """

    # Default class mappings
    DOCSTRUCTBENCH_CLASSES = {
        0: "title",
        1: "plain_text",
        2: "abandon",
        3: "figure",
        4: "figure_caption",
        5: "table",
        6: "table_caption",
        7: "table_footnote",
        8: "isolate_formula",
        9: "formula_caption",
    }
    DOCLAYNET_CLASSES = {
        0: "Caption",
        1: "Footnote",
        2: "Formula",
        3: "List-item",
        4: "Page-footer",
        5: "Page-header",
        6: "Picture",
        7: "Section-header",
        8: "Table",
        9: "Text",
        10: "Title",
    }

    def __init__(
        self,
        weights_path: Union[str, Path],
        config_path: Union[str, Path, None] = None,
        model_type: str = "auto",
    ):
        """
        Initialize model.

        The weights are not loaded here; see the ``model`` property.

        Parameters
        ----------
        weights_path : str or Path
            Path to model weights (.pt file)
        config_path : str or Path, optional
            Path to a JSON file whose ``"class_names"`` entry is either a list
            of names (index = class id) or a mapping of class id to name.
            If None (or empty), class names are chosen from ``model_type``.
        model_type : str, default="auto"
            Model type: "docstructbench", "doclaynet", or "auto" (detect from
            the weights filename). Matching is case-insensitive.

        Raises
        ------
        ValueError
            If ``model_type`` is not one of the supported values (only checked
            when no ``config_path`` is given).
        KeyError
            If the config file has no ``"class_names"`` entry.
        """
        self.weights_path = Path(weights_path)
        self._model = None  # created on first access of ``model``

        # Load class names from config or auto-detect.
        # Truthiness (not ``is not None``) is deliberate: an empty string is
        # treated the same as "no config given".
        if config_path:
            self.class_names = self._load_class_names(config_path)
        else:
            self.class_names = self._get_class_names(model_type)

    @staticmethod
    def _load_class_names(config_path: Union[str, Path]) -> Dict[int, str]:
        """Read the ``"class_names"`` entry of a JSON config as ``{id: name}``."""
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        names = config["class_names"]
        if isinstance(names, dict):
            # JSON object keys are always strings; convert them back to ids.
            return {int(class_id): name for class_id, name in names.items()}
        return dict(enumerate(names))

    def _get_class_names(self, model_type: str) -> Dict[int, str]:
        """
        Get class names based on model type.

        Returns a fresh copy so that edits to an instance's ``class_names``
        never leak into the class-level default tables.
        """
        key = model_type.lower() if isinstance(model_type, str) else model_type

        if key == "auto":
            # Anything whose filename does not mention DocLayNet is assumed
            # to be a DocStructBench model.
            stem = self.weights_path.stem.lower()
            key = "doclaynet" if "doclaynet" in stem else "docstructbench"

        if key == "doclaynet":
            return dict(self.DOCLAYNET_CLASSES)
        if key == "docstructbench":
            return dict(self.DOCSTRUCTBENCH_CLASSES)
        raise ValueError(f"Unknown model type: {model_type}")

    @property
    def model(self) -> "YOLO":
        """Lazy-load the YOLO model."""
        if self._model is None:
            self._model = YOLO(str(self.weights_path))
        return self._model

    def predict(
        self,
        source: Union[str, Path, "Image.Image", np.ndarray],
        confidence: float = 0.2,
        image_size: int = 1024,
        device: str = "cpu",
    ) -> List[Dict]:
        """
        Run inference on an image.

        Parameters
        ----------
        source : str, Path, PIL.Image, or np.ndarray
            Input image
        confidence : float, default=0.2
            Confidence threshold
        image_size : int, default=1024
            Input image size
        device : str, default="cpu"
            Device to run on ("cpu", "cuda", "mps")

        Returns
        -------
        List[Dict]
            List of detections, each with keys:

            - class_id: int
            - class_name: str
            - confidence: float
            - bbox: [x1, y1, x2, y2]

            If the underlying model returns several results, their detections
            are concatenated in order into this single flat list.
        """
        results = self.model.predict(
            # Path objects are handed over as plain strings; every other
            # source type is passed through untouched.
            source=str(source) if isinstance(source, Path) else source,
            imgsz=image_size,
            conf=confidence,
            device=device,
            save=False,
            verbose=False,
        )

        detections: List[Dict] = []
        for result in results:
            detections.extend(self._result_to_detections(result))
        return detections

    def _result_to_detections(self, result) -> List[Dict]:
        """Convert one prediction result into a list of detection dicts."""
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            # A result without box output contributes no detections.
            return []

        # Convert each field once per result instead of indexing per box.
        class_ids = map(int, boxes.cls.tolist())
        scores = boxes.conf.tolist()
        corners = boxes.xyxy.tolist()

        return [
            {
                "class_id": cls,
                # Ids missing from the table get a generic placeholder name.
                "class_name": self.class_names.get(cls, f"class_{cls}"),
                "confidence": float(score),
                "bbox": bbox,
            }
            for cls, score, bbox in zip(class_ids, scores, corners)
        ]