smallGroupProject / ui /model_loader.py
k23064919's picture
quickfix
b727fe2
import sys
from pathlib import Path

# Put the directory two levels above this file on sys.path *before* the
# project-local imports below. The append used to run after those imports,
# where it could no longer help them resolve.
# NOTE(review): presumably this is the project root holding `config` and
# `models` -- confirm against the repository layout.
_PROJECT_ROOT = str(Path(__file__).parent.parent)
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

import torch  # noqa: E402
from clearml import Task  # noqa: E402

import config  # noqa: E402
from models.modelOne import modelOne  # noqa: E402
from models.modelTwo import BetterCNN  # noqa: E402

# Maps the 'class' value of a config.MODEL_CONFIGS entry to the model class
# that ModelLoader instantiates.
MODEL_CLASSES = {
    "modelOne": modelOne,
    "betterCNN": BetterCNN,
}

# Name of the ClearML task artifact that ModelLoader downloads as the weights.
MODEL_ARTIFACT_NAME = 'best_model'
class ModelLoader:
    """Load models whose weights are stored as ClearML task artifacts.

    Loaded models are kept in ``modelCache`` (keyed by model name) so each
    model is fetched and built at most once per loader instance.
    """

    def __init__(self, noOfClasses=39):
        """Pick the torch device and start with an empty model cache.

        Args:
            noOfClasses: Value passed as ``noOfClasses`` to every model
                constructor. Defaults to 39, the value that was previously
                hard-coded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.noOfClasses = noOfClasses
        self.modelCache = {}

    def loadFromClearml(self, modelName):
        """Download the weights for ``modelName`` and return the built model.

        The entry ``config.MODEL_CONFIGS[modelName]`` supplies the ClearML task
        id (``'clearml_task_id'``) and the key into ``MODEL_CLASSES``
        (``'class'``).

        Args:
            modelName: Key into ``config.MODEL_CONFIGS``.

        Returns:
            The model with weights loaded, moved to ``self.device`` and
            switched to eval mode.

        Raises:
            ValueError: No (non-empty) configuration exists for ``modelName``.
            KeyError: The configuration lacks ``'clearml_task_id'`` or ``'class'``.
            RuntimeError: The configured class is unknown, or fetching the
                artifact / loading the weights failed (original error chained
                as ``__cause__``).
        """
        modelConfig = config.MODEL_CONFIGS.get(modelName)
        if not modelConfig:
            raise ValueError(f"ClearML configuration not found for model: {modelName}")

        taskID = modelConfig['clearml_task_id']
        className = modelConfig['class']

        # Resolve the class before any network access so a misconfigured
        # entry fails fast with a readable message instead of a bare KeyError
        # surfacing after the download.
        ModelClass = MODEL_CLASSES.get(className)
        if ModelClass is None:
            raise RuntimeError(
                f"Unknown model class '{className}' configured for model: {modelName}"
            )

        try:
            print(f"Attempting to fetch '{modelName}' from ClearML task: {taskID}")
            task = Task.get_task(task_id=taskID)
            print("Available artifacts:", task.artifacts.keys())

            artifact = task.artifacts.get(MODEL_ARTIFACT_NAME)
            if artifact is None:
                raise RuntimeError(
                    f"Artifact '{MODEL_ARTIFACT_NAME}' not found in ClearML task {taskID}"
                )

            modelPath = artifact.get_local_copy()
            if modelPath is None:
                raise RuntimeError(
                    f"Artifact '{MODEL_ARTIFACT_NAME}' could not be downloaded (returned None)"
                )
            print(f"Weights downloaded to: {modelPath}")

            model = ModelClass(noOfClasses=self.noOfClasses)

            # NOTE(security): torch.load unpickles the file. The artifact comes
            # from a remote task, so only point this at trusted ClearML tasks
            # (or pass weights_only=True where the installed torch supports it).
            stateDict = torch.load(modelPath, map_location=self.device)
            model.load_state_dict(stateDict)
            model.to(self.device)
            model.eval()
            return model
        except Exception as e:
            print(f"Error loading from ClearML for {modelName}: {e}")
            # Chain explicitly so the original traceback is kept as __cause__.
            raise RuntimeError(f"Failed to load model from ClearML: {e}") from e

    def loadModel(self, modelName):
        """Return the model for ``modelName``, loading it on first use.

        Args:
            modelName: Key into ``config.MODEL_CONFIGS``.

        Returns:
            The cached model if present, otherwise a freshly loaded one
            (which is then cached).

        Raises:
            RuntimeError: Loading failed for any reason; the underlying error
                is included in the message and chained as ``__cause__``.
        """
        if modelName in self.modelCache:
            return self.modelCache[modelName]

        try:
            model = self.loadFromClearml(modelName)
        except Exception as e:
            # Surface the real cause: not every failure is a connection
            # problem (e.g. a missing configuration entry).
            raise RuntimeError(
                f"Could not load model {modelName}. Check ClearML connection. Cause: {e}"
            ) from e

        self.modelCache[modelName] = model
        return model