import os
import sys

# Append the directory two levels above this file to the import path so the
# absolute ``models`` / ``utils`` imports below resolve. This must run before
# those imports, which is why the import block is not grouped conventionally.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import time

import numpy as np
import torch
from PIL import Image

from models.base_model import BaseModelMainModel
from utils import configs
from utils.functional import image_augmentations, active_learning_uncertainty
from .lightning_module import ImageClassificationLightningModule


class DeepLearningModel(BaseModelMainModel):
    """Inference wrapper around a checkpointed image-classification network.

    On construction the best checkpoint for the given ``name_model`` /
    ``support_set_method`` pair is loaded from ``configs.WEIGHTS_PATH`` and
    put into evaluation mode. ``predict`` then classifies a single image.
    """

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        """Store the configuration via the base class and load the weights.

        Args:
            name_model: Model identifier; used as a checkpoint sub-directory
                and as a key into ``configs.NAME_MODELS``.
            freeze_model: When true, all parameters are marked as not
                requiring gradients.
            pretrained_model: Forwarded to the Lightning module; when false,
                top-level layers are re-initialised (see ``init_model``).
            support_set_method: Checkpoint sub-directory and threshold key
                (the demo below uses ``"1_shot"``).
        """
        super().__init__(name_model, freeze_model, pretrained_model, support_set_method)
        self.init_model()

    def init_model(self):
        """Load ``best.ckpt``, unwrap the network and prepare it for inference."""
        checkpoint_path = os.path.join(
            configs.WEIGHTS_PATH,
            self.name_model,
            self.support_set_method,
            "best.ckpt",
        )
        lightning_module = ImageClassificationLightningModule.load_from_checkpoint(
            checkpoint_path,
            name_model=self.name_model,
            freeze_model=self.freeze_model,
            pretrained_model=self.pretrained_model,
        )
        # Keep only the wrapped network; the Lightning wrapper is not needed
        # for inference.
        self.model = lightning_module.model

        # NOTE(review): this re-initialises direct child layers *after* the
        # checkpoint has been loaded whenever ``pretrained_model`` is false,
        # which discards the loaded weights of those layers. Behaviour is kept
        # as-is; confirm that this is intended.
        for layer in self.model.children():
            if hasattr(layer, "reset_parameters") and not self.pretrained_model:
                layer.reset_parameters()

        # Fixed: the attribute is ``requires_grad`` (the previous
        # ``required_grad`` merely created an unused attribute), and a frozen
        # model must have gradients disabled rather than enabled.
        for param in self.model.parameters():
            param.requires_grad = not self.freeze_model

        # ``self.device`` is not set in this class; presumably provided by
        # ``BaseModelMainModel`` - verify against the base class.
        self.model.to(self.device)
        self.model.eval()

    def predict(self, image: np.ndarray) -> dict:
        """Classify a single image.

        Args:
            image: Image array passed to the ``image_augmentations()``
                pipeline as its ``image`` keyword (the demo below passes an
                RGB image loaded with PIL).

        Returns:
            A dict with keys ``"character"`` (an entry of
            ``configs.CLASS_CHARACTERS``), ``"confidence"`` (softmax
            probability of the arg-max class) and ``"inference_time"``
            (seconds spent in the forward pass, measured with
            ``time.perf_counter``). When the uncertainty score exceeds the
            configured out-of-distribution threshold, ``"character"`` is the
            last entry of ``configs.CLASS_CHARACTERS`` instead of the arg-max
            class.
        """
        image_input = image_augmentations()(image=image)["image"]
        # Add a leading batch dimension of size one.
        image_input = image_input.unsqueeze(0).to(self.device)

        with torch.no_grad():
            start_time = time.perf_counter()
            logits = self.model(image_input)
            inference_time = time.perf_counter() - start_time
            probabilities = torch.softmax(logits, dim=1)

        # Single-image batch: work on the one row of class probabilities.
        probabilities = probabilities.detach().cpu().numpy()[0]
        result_index = int(np.argmax(probabilities))
        confidence = probabilities[result_index]

        uncertainty_score = active_learning_uncertainty(probabilities)
        # Clamp non-positive scores to zero.
        uncertainty_score = uncertainty_score if uncertainty_score > 0 else 0

        threshold = configs.NAME_MODELS[self.name_model][
            "deep_learning_out_of_distribution_threshold"
        ][self.support_set_method]

        if uncertainty_score > threshold:
            # Too uncertain: report the last configured class (presumably the
            # catch-all / unknown label - confirm in ``configs``).
            character = configs.CLASS_CHARACTERS[-1]
        else:
            character = configs.CLASS_CHARACTERS[result_index]

        return {
            "character": character,
            "confidence": confidence,
            "inference_time": inference_time,
        }


if __name__ == "__main__":
    model = DeepLearningModel("resnet50", True, True, "1_shot")
    image = np.array(
        Image.open(
            "../../assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png"
        ).convert("RGB")
    )
    result = model.predict(image)
    print(result)