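"""
Model loading utilities: loads the CLIP encoder and the SVM classifier
(downloaded from the Hugging Face Hub) once, and exposes them as a shared
`models` instance that other modules can import.
"""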
import clip
import torch
import joblib
from pathlib import Path
from huggingface_hub import hf_hub_download


class ModelLoader:
    """
    A class to load and hold the machine learning models.
    This ensures that models are loaded only once.
    """

    def __init__(self, clip_model_name: str, svm_repo_id: str, svm_filename: str):
        """
        Initializes the ModelLoader and loads the models.

        Args:
            clip_model_name (str): The name of the CLIP model to load (e.g., 'ViT-L/14').
            svm_repo_id (str): The repository ID on Hugging Face (e.g., 'rhnsa/ai_human_image_detector').
            svm_filename (str): The name of the model file in the repository (e.g., 'model.joblib').
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.clip_model, self.clip_preprocess = self._load_clip_model(clip_model_name)
        self.svm_model = self._load_svm_model(repo_id=svm_repo_id, filename=svm_filename)
        print("Models loaded successfully.")

    def _load_clip_model(self, model_name: str):
        """
        Loads the specified CLIP model and its preprocessor.

        Args:
            model_name (str): The name of the CLIP model.

        Returns:
            A tuple containing the loaded CLIP model and its preprocess function.
        """
        try:
            model, preprocess = clip.load(model_name, device=self.device)
            return model, preprocess
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise

    def _load_svm_model(self, repo_id: str, filename: str):
        """
        Downloads and loads the SVM model from a Hugging Face Hub repository.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file in the repository.

        Returns:
            The loaded SVM model object.
        """
        print(f"Downloading SVM model from Hugging Face repo: {repo_id}")
        try:
            # Download the model file from the Hub. It returns the cached path.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"SVM model downloaded to: {model_path}")
            # Load the model from the downloaded path.
            svm_model = joblib.load(model_path)
            return svm_model
        except Exception as e:
            print(f"Error downloading or loading SVM model from Hugging Face: {e}")
            raise


# --- Global Model Instance ---
# This creates a single instance of the models that can be imported by other modules.
CLIP_MODEL_NAME = 'ViT-L/14'
SVM_REPO_ID = 'rhnsa/ai_human_image_detector'
SVM_FILENAME = 'svm_model_real.joblib'  # The name of your model file in the Hugging Face repo

# This instance will be created when the application starts.
models = ModelLoader(
    clip_model_name=CLIP_MODEL_NAME,
    svm_repo_id=SVM_REPO_ID,
    svm_filename=SVM_FILENAME
)
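
# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of how a caller might use the shared `models`
# instance: preprocess an image with CLIP, encode it, and classify the
# resulting embedding with the SVM. The image path is a placeholder, and the
# L2-normalization step is an assumption about how the SVM's training
# features were prepared; adjust both to match the actual pipeline.
if __name__ == "__main__":
    from PIL import Image

    image = Image.open("example.jpg")  # placeholder input image
    image_tensor = models.clip_preprocess(image).unsqueeze(0).to(models.device)
    with torch.no_grad():
        embedding = models.clip_model.encode_image(image_tensor)
        # Assumed normalization; remove if the SVM was trained on raw embeddings.
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    prediction = models.svm_model.predict(embedding.cpu().numpy())
    print(f"Predicted label: {prediction[0]}")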