Spaces:
Running
Running
| import os | |
| import spacy | |
| import torch | |
| from doctr.models import ocr_predictor | |
| from loguru import logger | |
class ModelManager:
    """Singleton model manager for pre-loading all models at startup.

    Every ``ModelManager()`` call returns the same shared instance (see
    ``__new__``). Models are loaded on the first
    ``await ensure_models_loaded()`` and kept on that instance.
    """

    # spaCy pipeline loaded when SPACY_MODEL_NAME is not set in the environment.
    DEFAULT_SPACY_MODEL_NAME = "en_core_web_sm"

    _instance = None
    _doctr_model = None
    _spacy_model = None
    # Name actually passed to spacy.load(); None until the spaCy model is loaded.
    _spacy_model_name = None
    _device = None
    _models_loaded = False

    def __new__(cls):
        # Hand out the single shared instance, creating it on first use.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Intentionally empty: __init__ runs on every ModelManager() call, so
        # it must not reset state held by the shared instance.
        pass

    @classmethod
    def _resolve_spacy_model_name(cls):
        """Return SPACY_MODEL_NAME from the environment, or the default name."""
        return os.getenv("SPACY_MODEL_NAME", cls.DEFAULT_SPACY_MODEL_NAME)

    async def _load_models(self):
        """Load the doctr OCR predictor and the spaCy pipeline.

        Although declared ``async``, nothing here is awaited: every load call
        is blocking and runs directly on the calling (event-loop) thread.
        NOTE(review): consider offloading with ``asyncio.to_thread`` if loop
        responsiveness during loading matters; that would also require a
        lock to prevent concurrent double-loading.

        If loading the spaCy model raises, the doctr model already assigned
        is kept but ``_models_loaded`` stays False, so the next
        ``ensure_models_loaded()`` call reloads everything.
        """
        logger.info("π Starting model pre-loading...")
        # Prefer CUDA when available; otherwise fall back to CPU.
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"π± Using device: {self._device}")

        # Load doctr model and move both its detection and recognition
        # sub-models onto the selected device.
        logger.info("π Loading doctr OCR model...")
        self._doctr_model = ocr_predictor(pretrained=True)
        self._doctr_model.det_predictor.model = (
            self._doctr_model.det_predictor.model.to(self._device)
        )
        self._doctr_model.reco_predictor.model = (
            self._doctr_model.reco_predictor.model.to(self._device)
        )
        logger.info("β Doctr model loaded successfully!")

        # Load spaCy model, remembering which name was actually used so that
        # get_model_status() reports it even if the environment changes later.
        spacy_model_name = self._resolve_spacy_model_name()
        self._spacy_model = spacy.load(spacy_model_name)
        self._spacy_model_name = spacy_model_name
        logger.info("β spaCy model loaded successfully!")

        self._models_loaded = True
        logger.info("π All models loaded successfully!")

    def doctr_model(self):
        """Return the loaded doctr OCR predictor, or None before loading.

        NOTE(review): plain method, not a property; call as ``doctr_model()``.
        """
        return self._doctr_model

    def spacy_model(self):
        """Return the loaded spaCy pipeline, or None before loading."""
        return self._spacy_model

    def device(self):
        """Return the torch device chosen at load time, or None before loading."""
        return self._device

    def models_loaded(self):
        """Return True once every model has been loaded successfully."""
        return self._models_loaded

    async def ensure_models_loaded(self):
        """Load the models if they are not loaded yet.

        Returns:
            Always True; any loading failure propagates as an exception.
        """
        if not self._models_loaded:
            await self._load_models()
        return True

    async def get_model_status(self):
        """Return a JSON-friendly summary of what is currently loaded.

        ``spacy_model_name`` is the name actually passed to ``spacy.load``
        once loading has happened; before that, it is the name that would be
        used (SPACY_MODEL_NAME or the default).
        """
        if self._spacy_model_name is not None:
            spacy_model_name = self._spacy_model_name
        else:
            spacy_model_name = self._resolve_spacy_model_name()
        return {
            "doctr_model": self._doctr_model is not None,
            "spacy_model": self._spacy_model is not None,
            "device": str(self._device),
            "models_loaded": self._models_loaded,
            "spacy_model_name": spacy_model_name,
        }
# Module-wide shared manager. ModelManager is a singleton, so this is the
# same object that any later ModelManager() call returns.
model_manager: ModelManager = ModelManager()