# arabic-summarizer-classifier / model_manager.py
# (source: mabosaimi / Fkhrayef, commit 5fc9256)
import os
from typing import Any, Dict, Optional

from traditional_classifier import TraditionalClassifier

# Modern (transformer/LSTM) classifiers need PyTorch + transformers; keep this
# module importable without them and record availability in a flag instead.
try:
    from modern_classifier import ModernClassifier
    MODERN_MODELS_AVAILABLE = True
except ImportError:
    MODERN_MODELS_AVAILABLE = False
class ModelManager:
    """Manages different types of Arabic text classification models with per-request model selection and caching."""

    # Registry of known models. Paths are relative to the working directory;
    # existence is checked lazily in _get_model / get_available_models.
    AVAILABLE_MODELS = {
        "traditional_svm": {
            "type": "traditional",
            "classifier_path": "models/traditional_svm_classifier.joblib",
            "vectorizer_path": "models/traditional_tfidf_vectorizer_classifier.joblib",
            "description": "Traditional SVM classifier with TF-IDF vectorization"
        },
        "modern_bert": {
            "type": "modern",
            "model_type": "bert",
            "model_path": "models/modern_bert_classifier.safetensors",
            "config_path": "config.json",
            "description": "Modern BERT-based transformer classifier"
        },
        "modern_lstm": {
            "type": "modern",
            "model_type": "lstm",
            "model_path": "models/modern_lstm_classifier.pth",
            "description": "Modern LSTM-based neural network classifier"
        }
    }

    def __init__(self, default_model: str = "traditional_svm"):
        """Create a manager.

        Args:
            default_model: Key into AVAILABLE_MODELS used when a caller does
                not name a model explicitly. Not validated here; validation
                happens on first use in _get_model.
        """
        self.default_model = default_model
        # model_name -> loaded classifier instance; models are loaded at most once.
        self._model_cache: Dict[str, Any] = {}

    def _resolve_name(self, model_name: Optional[str]) -> str:
        """Return *model_name*, falling back to the default when None."""
        return self.default_model if model_name is None else model_name

    def _manager_metadata(self, model_name: str) -> Dict[str, str]:
        """Metadata dict stamped onto every prediction result."""
        return {
            "model_used": model_name,
            "model_description": self.AVAILABLE_MODELS[model_name]["description"]
        }

    def _get_model(self, model_name: str):
        """Get model instance, loading from cache or creating new one.

        Raises:
            ValueError: if *model_name* is not in AVAILABLE_MODELS, or its
                registry entry has an unrecognized "type".
            FileNotFoundError: if a required model artifact is missing on disk.
            ImportError: if a modern model is requested but torch/transformers
                are unavailable.
        """
        if model_name not in self.AVAILABLE_MODELS:
            raise ValueError(f"Model '{model_name}' not available. Available models: {list(self.AVAILABLE_MODELS.keys())}")
        if model_name in self._model_cache:
            return self._model_cache[model_name]

        model_config = self.AVAILABLE_MODELS[model_name]
        if model_config["type"] == "traditional":
            classifier_path = model_config["classifier_path"]
            vectorizer_path = model_config["vectorizer_path"]
            if not os.path.exists(classifier_path):
                raise FileNotFoundError(f"Classifier file not found: {classifier_path}")
            if not os.path.exists(vectorizer_path):
                raise FileNotFoundError(f"Vectorizer file not found: {vectorizer_path}")
            model = TraditionalClassifier(classifier_path, vectorizer_path)
        elif model_config["type"] == "modern":
            if not MODERN_MODELS_AVAILABLE:
                raise ImportError("Modern models require PyTorch and transformers")
            model_path = model_config["model_path"]
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")
            # A missing config file is tolerated: pass None and let the
            # classifier fall back to its built-in defaults.
            config_path = model_config.get("config_path")
            if config_path and not os.path.exists(config_path):
                config_path = None
            model = ModernClassifier(
                model_type=model_config["model_type"],
                model_path=model_path,
                config_path=config_path
            )
        else:
            # Defensive: previously an unknown type left `model` unbound and
            # the cache assignment below raised a confusing NameError.
            raise ValueError(f"Unknown model type: {model_config['type']}")

        self._model_cache[model_name] = model
        return model

    def predict(self, text: str, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Predict using the specified model (or default if none specified).

        Returns the classifier's result dict with a "model_manager" entry added.
        """
        model_name = self._resolve_name(model_name)
        model = self._get_model(model_name)
        result = model.predict(text)
        result["model_manager"] = self._manager_metadata(model_name)
        return result

    def predict_batch(self, texts: list, model_name: Optional[str] = None) -> list:
        """Predict batch using the specified model (or default if none specified).

        Each result dict in the returned list gets a "model_manager" entry.
        """
        model_name = self._resolve_name(model_name)
        model = self._get_model(model_name)
        results = model.predict_batch(texts)
        for result in results:
            result["model_manager"] = self._manager_metadata(model_name)
        return results

    def get_model_info(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Get information about a specific model (or default if none specified).

        Note: this loads the model if it is not cached yet, so "is_cached"
        reflects the state *after* loading.
        """
        model_name = self._resolve_name(model_name)
        model = self._get_model(model_name)
        model_info = model.get_model_info()
        model_info.update({
            "model_manager": {
                "model_name": model_name,
                "model_description": self.AVAILABLE_MODELS[model_name]["description"],
                "model_config": self.AVAILABLE_MODELS[model_name],
                "is_cached": model_name in self._model_cache
            }
        })
        return model_info

    def get_available_models(self) -> Dict[str, Any]:
        """Get list of all available models.

        For each registered model, reports its description/type, whether all
        required files exist on disk, which files are missing, and whether it
        is the default and/or already cached. Does not load any model.
        """
        available = {}
        for model_name, config in self.AVAILABLE_MODELS.items():
            missing_files = []
            if config["type"] == "traditional":
                for file_key in ["classifier_path", "vectorizer_path"]:
                    if not os.path.exists(config[file_key]):
                        missing_files.append(config[file_key])
            elif config["type"] == "modern":
                if not os.path.exists(config["model_path"]):
                    missing_files.append(config["model_path"])
            available[model_name] = {
                "description": config["description"],
                "type": config["type"],
                "available": not missing_files,
                "missing_files": missing_files,
                "is_default": model_name == self.default_model,
                "is_cached": model_name in self._model_cache
            }
        return available

    def clear_cache(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Clear model cache (specific model or all models).

        Returns a message dict describing what was cleared; clearing an
        uncached or unknown name is a no-op, not an error.
        """
        if model_name:
            if model_name in self._model_cache:
                del self._model_cache[model_name]
                return {"message": f"Cache cleared for model: {model_name}"}
            else:
                return {"message": f"Model {model_name} was not cached"}
        else:
            cleared_count = len(self._model_cache)
            self._model_cache.clear()
            return {"message": f"Cache cleared for {cleared_count} models"}

    def get_cache_status(self) -> Dict[str, Any]:
        """Get information about cached models."""
        return {
            "cached_models": list(self._model_cache.keys()),
            "cache_count": len(self._model_cache),
            "default_model": self.default_model
        }