from threading import Lock

import numpy as np
import torch
import torchvision
import torchvision.models.segmentation as segmentation
from torchvision import models
from torchvision.models import detection
from ultralytics import YOLO

# import tensorrt
# import tensorrt as trt
# import onnx
# import onnxruntime as ort


class TorchModelFactory:
    """Process-wide singleton that lazily builds and caches pretrained models.

    Each ``create_*`` method validates a model name against the matching
    ``MODELS_*`` registry, builds the model on first request and caches it.
    Later calls with the same name return the same object. Torch models are
    moved to :attr:`device` and put into ``eval()`` mode before caching.
    First-time builds are serialized under ``_lock``. Cached models are never
    evicted.
    """

    _instance = None
    # Guards singleton creation and every first-time model build.
    _lock = Lock()
    # Per-category caches: model name -> built model instance.
    _feature_extract_models = {}
    _detect_models = {}
    _classification_models = {}
    _instance_models = {}
    _semantic_models = {}
    # Chosen once at class-definition time.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    MODELS_FEATURE_EXTRACT = {
        'resnet': lambda: models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1),
        'vgg16': lambda: models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1),
        'inception_v3': lambda: models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1),
        'mobilenet_v2': lambda: models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1),
        'densenet121': lambda: models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1),
    }
    MODELS_DETECT = {
        'RetinaNet': lambda: detection.retinanet_resnet50_fpn(
            weights=detection.RetinaNet_ResNet50_FPN_Weights.COCO_V1,
            weights_backbone=models.ResNet50_Weights.IMAGENET1K_V1),
        'FasterRCNN': lambda: detection.fasterrcnn_resnet50_fpn(
            weights=detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1,
            weights_backbone=models.ResNet50_Weights.IMAGENET1K_V1),
        # NOTE(review): key says SSDLite but this builds SSD300-VGG16, not
        # ssdlite320_mobilenet_v3_large — confirm which model is intended.
        'SSDLite': lambda: detection.ssd300_vgg16(weights=detection.SSD300_VGG16_Weights.COCO_V1),
        # Not a plain torch module: built via create_yolo_detect_model only.
        'Yolo': lambda: YOLO("yolov8n.pt"),
    }
    MODELS_CLASSIFICATION = {
        'resnet': lambda: models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1),
        'mobilenetv2': lambda: models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1),
        'shufflenetv2': lambda: models.shufflenet_v2_x1_0(weights=models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1),
    }
    MODELS_INSTANCE = {
        'maskrcnn': lambda: detection.maskrcnn_resnet50_fpn(weights=detection.MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
        # NOTE(review): requires network access and a hubconf.py in that repo — verify.
        'yolact': lambda: torch.hub.load('dbolya/yolact', 'yolact_resnet50', pretrained=True),
    }
    MODELS_SEMANTIC = {
        'deeplabv3': lambda: segmentation.deeplabv3_resnet101(
            weights=segmentation.DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
        # NOTE(review): torchvision.models.segmentation does not appear to define
        # pspnet_resnet50, so building this entry likely raises AttributeError.
        'pspnet': lambda: segmentation.pspnet_resnet50(pretrained=True),
        # NOTE(review): key says BiSeNetV1 but this loads a DeepLabV3-ResNet50 from
        # a third-party hub repo — confirm the intended model/source.
        'bisenetv1': lambda: torch.hub.load('catalyst-team/deeplabv3', 'deeplabv3_resnet50', pretrained=True),
    }

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first call."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(TorchModelFactory, cls).__new__(cls)
        return cls._instance

    @staticmethod
    def _get_or_create(registry, cache, model_name, prepare=True):
        """Return ``cache[model_name]``, building it from ``registry`` on first use.

        Args:
            registry: mapping of model name -> zero-arg builder callable.
            cache: mapping where built models are stored (mutated in place).
            model_name: key to look up.
            prepare: when True, move the fresh model to ``device`` and call
                ``eval()`` on it before caching.

        Raises:
            ValueError: if ``model_name`` is not a key of ``registry``.
        """
        if model_name not in registry:
            raise ValueError('Invalid model name')
        model = cache.get(model_name)
        if model is None:
            with TorchModelFactory._lock:
                # Re-check under the lock: another thread may have built it meanwhile.
                model = cache.get(model_name)
                if model is None:
                    model = registry[model_name]()
                    if prepare:
                        model = model.to(TorchModelFactory.device)
                        model.eval()
                    cache[model_name] = model
        return model

    @staticmethod
    def create_feature_extract_model(model_name):
        """Return the cached feature-extraction model named ``model_name``.

        Raises:
            ValueError: if the name is not in ``MODELS_FEATURE_EXTRACT``.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_FEATURE_EXTRACT,
            TorchModelFactory._feature_extract_models,
            model_name,
        )

    @staticmethod
    def create_detect_model(model_name):
        """Return the cached detection model named ``model_name``.

        ``'Yolo'`` is delegated to :meth:`create_yolo_detect_model`. Both
        methods share the ``_detect_models['Yolo']`` slot, and the ultralytics
        object must not go through the generic ``.to()``/``.eval()`` path:
        ``nn.Module.eval()`` calls ``self.train(False)``, and ultralytics'
        ``Model`` overrides ``train()`` as a training entry point (at least in
        some versions — confirm for the pinned version).

        Raises:
            ValueError: if the name is not in ``MODELS_DETECT``.
        """
        if model_name == 'Yolo':
            return TorchModelFactory.create_yolo_detect_model()
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_DETECT,
            TorchModelFactory._detect_models,
            model_name,
        )

    @staticmethod
    def create_yolo_detect_model():
        """Return the cached ultralytics YOLO detector.

        The object is cached as built, without a ``.to(device)`` or ``eval()``
        call.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_DETECT,
            TorchModelFactory._detect_models,
            'Yolo',
            prepare=False,
        )

    @staticmethod
    def create_classication_model(model_name):
        """Return the cached classification model named ``model_name``.

        The misspelled name is kept for existing callers. Prefer the alias
        ``create_classification_model``.

        Raises:
            ValueError: if the name is not in ``MODELS_CLASSIFICATION``.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_CLASSIFICATION,
            TorchModelFactory._classification_models,
            model_name,
        )

    # Correctly spelled alias for create_classication_model.
    create_classification_model = create_classication_model

    @staticmethod
    def create_instance_model(model_name):
        """Return the cached instance-segmentation model named ``model_name``.

        Raises:
            ValueError: if the name is not in ``MODELS_INSTANCE``.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_INSTANCE,
            TorchModelFactory._instance_models,
            model_name,
        )

    @staticmethod
    def create_semantic_model(model_name):
        """Return the cached semantic-segmentation model named ``model_name``.

        Raises:
            ValueError: if the name is not in ``MODELS_SEMANTIC``.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_SEMANTIC,
            TorchModelFactory._semantic_models,
            model_name,
        )