Spaces:
Sleeping
Sleeping
| import torch | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| from model import FFTCNN # Import the model architecture | |
class ModelLoader:
    """
    A class to load and hold the PyTorch CNN model.

    On construction it selects a device (``"cuda"`` when
    ``torch.cuda.is_available()``, otherwise ``"cpu"``), downloads the
    checkpoint from the Hugging Face Hub and keeps the ready-to-use network.

    Attributes:
        device (str): ``"cuda"`` or ``"cpu"``, chosen once in ``__init__``.
        fft_model (FFTCNN): The loaded network, moved to ``device`` and
            switched to evaluation mode.
    """

    def __init__(self, model_repo_id: str, model_filename: str):
        """
        Initializes the ModelLoader and loads the model.

        Args:
            model_repo_id (str): The repository ID on Hugging Face.
            model_filename (str): The name of the model file (.pth) in the repository.

        Raises:
            Exception: Any error from downloading or loading the weights is
                printed and re-raised unchanged by ``_load_fft_model``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.fft_model = self._load_fft_model(repo_id=model_repo_id, filename=model_filename)
        print("FFT CNN model loaded successfully.")

    def _read_state_dict(self, model_path):
        """
        Deserialize a saved ``state_dict`` from disk onto ``self.device``.

        ``weights_only=True`` restricts unpickling to tensors and primitive
        containers. The file comes from a remote repository, and a full pickle
        load would execute any code embedded in a tampered checkpoint.
        (Requires PyTorch >= 1.13, where the ``weights_only`` flag exists.)

        Args:
            model_path (str | os.PathLike): Path to the ``.pth`` file.

        Returns:
            The deserialized state dict, with tensors mapped to ``self.device``.

        Raises:
            pickle.UnpicklingError: In current PyTorch releases, if the file
                contains objects outside the weights-only allowlist.
        """
        return torch.load(model_path, map_location=torch.device(self.device), weights_only=True)

    def _load_fft_model(self, repo_id: str, filename: str):
        """
        Downloads and loads the FFT CNN model from a Hugging Face Hub repository.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file (.pth) in the repository.

        Returns:
            The loaded PyTorch model object, on ``self.device`` and in eval mode.

        Raises:
            Exception: Any download/load failure is printed, then re-raised as-is.
        """
        print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
        try:
            # Download the model file from the Hub. It returns the cached local path.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Model downloaded to: {model_path}")
            # Initialize the model architecture
            model = FFTCNN()
            # Load the saved weights (state_dict) into the model, without
            # allowing arbitrary pickled objects.
            model.load_state_dict(self._read_state_dict(model_path))
            # Move the parameters to the selected device
            model.to(self.device)
            # Set the model to evaluation mode (disables training-only
            # behaviour such as dropout)
            model.eval()
            return model
        except Exception as e:
            print(f"Error downloading or loading model from Hugging Face: {e}")
            raise
# --- Global Model Instance ---
# Location of the trained checkpoint on the Hugging Face Hub.
MODEL_REPO_ID = 'rhnsa/real_forged_classifier'
MODEL_FILENAME = 'fft_cnn_model_78.pth'

# Built at import time: importing this module triggers the download and
# weight loading once, and the result is kept in the module-level name
# `models` (presumably imported by the app code — verify against callers).
models = ModelLoader(
    model_repo_id=MODEL_REPO_ID,
    model_filename=MODEL_FILENAME,
)