import sys
from pathlib import Path

import torch
from clearml import Task

# Make the project root importable before loading local modules.
sys.path.append(str(Path(__file__).parent.parent))

import config
from models.modelOne import modelOne
from models.modelTwo import BetterCNN

# Map the class names stored in config to the actual model classes.
MODEL_CLASSES = {
    "modelOne": modelOne,
    "betterCNN": BetterCNN,
}

MODEL_ARTIFACT_NAME = 'best_model'
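
# A minimal sketch of the config.MODEL_CONFIGS structure this loader expects.
# The keys 'clearml_task_id' and 'class' are read below; the placeholder task IDs
# here are assumptions, not values taken from config.py:
#
# MODEL_CONFIGS = {
#     "modelOne": {"clearml_task_id": "<clearml-task-id>", "class": "modelOne"},
#     "betterCNN": {"clearml_task_id": "<clearml-task-id>", "class": "betterCNN"},
# }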

class ModelLoader:
    def __init__(self):
        # Use the GPU when available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Cache loaded models so each one is fetched from ClearML only once.
        self.modelCache = {}

    def loadFromClearml(self, modelName):
        modelConfig = config.MODEL_CONFIGS.get(modelName)
        if not modelConfig:
            raise ValueError(f"ClearML configuration not found for model: {modelName}")

        taskID = modelConfig['clearml_task_id']
        className = modelConfig['class']

        try:
            print(f"Attempting to fetch '{modelName}' from ClearML task: {taskID}")
            task = Task.get_task(task_id=taskID)
            print("Available artifacts:", list(task.artifacts.keys()))

            artifact = task.artifacts.get(MODEL_ARTIFACT_NAME)
            if artifact is None:
                raise RuntimeError(
                    f"Artifact '{MODEL_ARTIFACT_NAME}' not found in ClearML task {taskID}"
                )

            modelPath = artifact.get_local_copy()
            if modelPath is None:
                raise RuntimeError(
                    f"Artifact '{MODEL_ARTIFACT_NAME}' could not be downloaded (returned None)"
                )
            print(f"Weights downloaded to: {modelPath}")

            # Instantiate the correct model class for this model name.
            ModelClass = MODEL_CLASSES[className]
            model = ModelClass(noOfClasses=39)

            # Load the downloaded weights and prepare the model for inference.
            stateDict = torch.load(modelPath, map_location=self.device)
            model.load_state_dict(stateDict)
            model.to(self.device)
            model.eval()
            return model

        except Exception as e:
            print(f"Error loading from ClearML for {modelName}: {e}")
            raise RuntimeError(f"Failed to load model from ClearML: {e}") from e
    def loadModel(self, modelName):
        # Return the cached instance if this model has already been loaded.
        if modelName in self.modelCache:
            return self.modelCache[modelName]
        try:
            model = self.loadFromClearml(modelName)
            self.modelCache[modelName] = model
            return model
        except Exception as e:
            raise RuntimeError(
                f"Could not load model {modelName}. Check ClearML connection. ({e})"
            ) from e