Spaces:
Build error
Build error
| # Standard library imports first | |
| import os | |
| import math | |
| import json | |
| import uuid | |
| from abc import ABC, abstractmethod | |
| from dataclasses import replace | |
| from math import sqrt | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple, Dict, Any | |
| # Get logger before any other imports | |
| from logger import get_logger | |
| logger = get_logger(__name__) | |
| # Third-party imports | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from skimage.morphology import skeletonize | |
| from skimage.measure import label | |
| from ultralytics import YOLO | |
| # Local imports | |
| from storage import StorageInterface | |
| from base import BaseDetector | |
| from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig | |
| from line_detectors import OpenCVLineDetector, DEEPLSD_AVAILABLE | |
| # Try to import DeepLSD, but don't fail if not available | |
| try: | |
| from line_detectors import DeepLSDDetector | |
| logger.info("Successfully imported DeepLSD") | |
| except ImportError as e: | |
| logger.warning(f"DeepLSD import failed: {str(e)}. Will use OpenCV fallback.") | |
| # Detection schema imports | |
| from detection_schema import ( | |
| BBox, Coordinates, Point, Line, Symbol, Tag, | |
| SymbolType, LineStyle, ConnectionType, JunctionType, Junction | |
| ) | |
| # Detector class hierarchy (base class + symbol, tag, line, point, junction detectors) | |
class Detector(ABC):
    """Base class for all detectors.

    Subclasses override :meth:`detect`. Debug visualizations are forwarded to
    an optional ``debug_handler`` exposing ``save_image(image, filename)``.

    Args:
        config: Detector-specific configuration object, stored as ``self.config``.
        debug_handler: Optional debug sink used by :meth:`save_debug_image`;
            when falsy, saving debug images is a no-op.
    """

    def __init__(self, config: Any, debug_handler=None):
        self.config = config
        self.debug_handler = debug_handler

    def detect(self, image: np.ndarray) -> Dict:
        """Perform detection on the image.

        Args:
            image: Input image.

        Returns:
            A dict of detection results (format defined by each subclass).

        Raises:
            NotImplementedError: always; concrete detectors must override this.
        """
        # Raise instead of silently returning None: a missing override would
        # otherwise violate the ``-> Dict`` contract and fail far from here.
        raise NotImplementedError(f"{type(self).__name__} must implement detect()")

    def save_debug_image(self, image: np.ndarray, filename: str):
        """Save a debug visualization if a debug handler is available.

        Args:
            image: Image to save.
            filename: Name passed through to ``debug_handler.save_image``.
        """
        if self.debug_handler:
            self.debug_handler.save_image(image, filename)
class SymbolDetector(Detector):
    """Detector for symbols in P&ID diagrams.

    Loads one YOLO model per entry in ``config.model_paths`` and merges the
    detections of all loaded models.

    Args:
        config: Must expose ``model_paths`` (mapping of model name -> weights
            path) and ``confidence_threshold`` (passed to YOLO as ``conf``).
        debug_handler: Optional debug image sink.
    """

    def __init__(self, config, debug_handler=None):
        super().__init__(config, debug_handler)
        self.models = {}
        for name, path in config.model_paths.items():
            if os.path.exists(path):
                self.models[name] = YOLO(path)
            else:
                # Missing weights are skipped so the remaining models still run.
                logger.warning(f"Model not found at {path}")

    def detect(self, image: np.ndarray) -> Dict:
        """Detect symbols using every loaded YOLO model.

        Args:
            image: Image passed directly to each YOLO model.

        Returns:
            ``{'detections': [...]}``. Each entry has the following keys:
            ``bbox`` (``[x1, y1, x2, y2]`` as floats), ``confidence`` (float),
            ``class`` (the model's class name) and ``model`` (the key from
            ``config.model_paths``).
        """
        results = []
        for model_name, model in self.models.items():
            # One image in -> take the single Results object out.
            model_results = model(image, conf=self.config.confidence_threshold)[0]
            boxes = model_results.boxes
            # ``boxes`` is None for non-detection models (e.g. classification);
            # skip them instead of crashing while iterating.
            if boxes is None:
                continue
            # Copy each tensor to host memory once instead of once per box.
            xyxy = boxes.xyxy.cpu().numpy()
            confs = boxes.conf.cpu().numpy()
            class_ids = boxes.cls.cpu().numpy()
            for (x1, y1, x2, y2), conf, cls in zip(xyxy, confs, class_ids):
                results.append({
                    'bbox': [float(x1), float(y1), float(x2), float(y2)],
                    'confidence': float(conf),
                    'class': model_results.names[int(cls)],
                    'model': model_name,
                })
        return {'detections': results}
class TagDetector(Detector):
    """Detect text tags in P&ID diagrams.

    Placeholder implementation: no OCR engine is wired up yet, so
    :meth:`detect` always reports an empty result.
    """

    def __init__(self, config, debug_handler=None):
        super().__init__(config, debug_handler)
        # Slot reserved for an OCR engine; stays unset until one is integrated.
        self.ocr = None

    def detect(self, image: np.ndarray) -> Dict:
        """Return the detected text tags (always none for now)."""
        # TODO: implement text detection and recognition.
        no_tags = []
        return {'detections': no_tags}
class LineDetector(Detector):
    """Detector for lines in P&ID diagrams.

    Uses DeepLSD when ``DEEPLSD_AVAILABLE`` is true and a ``model_path`` is
    given. Otherwise it falls back to ``OpenCVLineDetector``. The fallback
    also applies when DeepLSD cannot be imported or constructed.

    Args:
        config: Line-detection configuration, stored as ``self.config``.
        model_path: Path to DeepLSD weights; when falsy, OpenCV is used.
        model_config: Accepted for API compatibility; currently unused.
        device: Accepted for API compatibility; currently unused.
        debug_handler: Optional debug image sink.
    """

    def __init__(self, config, model_path=None, model_config=None, device='cpu', debug_handler=None):
        super().__init__(config, debug_handler)
        # NOTE(review): model_config and device are not forwarded to
        # DeepLSDDetector; confirm its constructor signature before wiring them.
        self.detector = None
        if DEEPLSD_AVAILABLE and model_path:
            try:
                # Imported locally: the module-level import is optional and may
                # have failed, which would leave the name undefined here.
                from line_detectors import DeepLSDDetector
                self.detector = DeepLSDDetector(model_path)
            except Exception as exc:  # any load failure degrades to OpenCV
                logger.warning(f"DeepLSD initialization failed ({exc}); falling back to OpenCV")
            else:
                logger.info("Using DeepLSD for line detection")
        if self.detector is None:
            self.detector = OpenCVLineDetector()
            logger.info("Using OpenCV for line detection")

    def detect(self, image: np.ndarray) -> Dict:
        """Delegate line detection to the selected backend.

        Args:
            image: Input image forwarded unchanged to the backend.

        Returns:
            Whatever the backend's ``detect`` returns.
        """
        return self.detector.detect(image)
class PointDetector(Detector):
    """Locate connection points in P&ID diagrams.

    Not implemented yet: :meth:`detect` always yields no points.
    """

    def detect(self, image: np.ndarray) -> Dict:
        """Return the connection points found in *image* (none until implemented)."""
        # TODO: implement connection-point detection.
        found = []
        return {'detections': found}
class JunctionDetector(Detector):
    """Locate junctions where lines meet in P&ID diagrams.

    Not implemented yet: :meth:`detect` always yields no junctions.
    """

    def detect(self, image: np.ndarray) -> Dict:
        """Return the line junctions found in *image* (none until implemented)."""
        # TODO: implement junction detection.
        junctions = []
        return {'detections': junctions}