Spaces:
Sleeping
Sleeping
from torchvision import models | |
import numpy as np | |
from torchvision.models import detection | |
import torch | |
import torchvision | |
import torchvision.models.segmentation as segmentation | |
from ultralytics import YOLO | |
from threading import Lock | |
# import tensorrt | |
# import tensorrt as trt | |
# import onnx | |
# import onnxruntime as ort | |
class TorchModelFactory:
    """Singleton that lazily builds and caches pretrained vision models.

    Models are grouped into five task families: feature extraction, detection,
    classification, instance segmentation and semantic segmentation. Each
    family has a registry of zero-argument builder callables (``MODELS_*``) and
    a matching cache dict (``_*_models``).

    The first request for a name calls its builder. The result is moved to
    ``device``, switched to eval mode and stored. Every later request for the
    same name returns that same object.

    The ``create_*`` methods are static. They can be called on the class
    (``TorchModelFactory.create_detect_model("FasterRCNN")``) or on an
    instance (``TorchModelFactory().create_detect_model("FasterRCNN")``).
    """

    _instance = None
    # One class-wide, non-reentrant lock guards singleton creation and every
    # model cache. Builders must therefore not call back into this factory,
    # or they would deadlock.
    _lock = Lock()
    _feature_extract_models = {}
    _detect_models = {}
    _classification_models = {}
    _instance_models = {}
    _semantic_models = {}
    # Selected once, when the class body executes (i.e. at module import).
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    MODELS_FEATURE_EXTRACT = {
        'resnet': lambda: models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1),
        'vgg16': lambda: models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1),
        'inception_v3': lambda: models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1),
        'mobilenet_v2': lambda: models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1),
        'densenet121': lambda: models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1),
    }

    MODELS_DETECT = {
        # NOTE(review): when full `weights` are given, torchvision appears to
        # discard `weights_backbone`, so that argument is likely redundant
        # here. Verify against the installed torchvision version.
        'RetinaNet': lambda: detection.retinanet_resnet50_fpn(
            weights=detection.RetinaNet_ResNet50_FPN_Weights.COCO_V1,
            weights_backbone=models.ResNet50_Weights.IMAGENET1K_V1),
        'FasterRCNN': lambda: detection.fasterrcnn_resnet50_fpn(
            weights=detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1,
            weights_backbone=models.ResNet50_Weights.IMAGENET1K_V1),
        # NOTE(review): despite the key, this builds SSD300 with a VGG16
        # backbone, not SSDlite. The key is kept for caller compatibility.
        'SSDLite': lambda: detection.ssd300_vgg16(weights=detection.SSD300_VGG16_Weights.COCO_V1),
        # ultralytics wrapper. create_yolo_detect_model() shares this cache slot.
        'Yolo': lambda: YOLO("yolov8n.pt"),
    }

    MODELS_CLASSIFICATION = {
        'resnet': lambda: models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1),
        'mobilenetv2': lambda: models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1),
        'shufflenetv2': lambda: models.shufflenet_v2_x1_0(weights=models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1),
    }

    MODELS_INSTANCE = {
        'maskrcnn': lambda: detection.maskrcnn_resnet50_fpn(weights=detection.MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
        # NOTE(review): confirm this hub repo actually exposes a
        # 'yolact_resnet50' entry point.
        'yolact': lambda: torch.hub.load('dbolya/yolact', 'yolact_resnet50', pretrained=True),
    }

    MODELS_SEMANTIC = {
        'deeplabv3': lambda: segmentation.deeplabv3_resnet101(
            weights=segmentation.DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
        # NOTE(review): torchvision.models.segmentation does not appear to
        # provide `pspnet_resnet50`. This entry likely fails with
        # AttributeError when requested. TODO replace or remove.
        'pspnet': lambda: segmentation.pspnet_resnet50(pretrained=True),
        # NOTE(review): the key says BiSeNet but this loads a DeepLabV3-ResNet50
        # hub entry. Confirm the repo and the intended model.
        'bisenetv1': lambda: torch.hub.load('catalyst-team/deeplabv3', 'deeplabv3_resnet50', pretrained=True),
    }

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on the first call.

        Extra constructor arguments are accepted and ignored.
        """
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have won the race.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    @staticmethod
    def _get_or_create(registry, cache, model_name, prepare=True):
        """Return the cached model for *model_name*, building it on first use.

        Args:
            registry: mapping of model name -> zero-argument builder callable.
            cache: dict in which the built model is stored, keyed by name.
            model_name: key to look up in *registry*.
            prepare: if true, a newly built model is moved to
                ``TorchModelFactory.device`` and switched to eval mode.

        Returns:
            The cached model object. The same object is returned on every call
            for the same name.

        Raises:
            ValueError: if *model_name* is not a key of *registry*.
        """
        if model_name not in registry:
            raise ValueError(
                f"Invalid model name: {model_name!r}; expected one of {sorted(registry)}"
            )
        # Lock-free fast path. Cache entries are only ever added, never removed.
        if model_name not in cache:
            with TorchModelFactory._lock:
                # Re-check: another thread may have built it while we waited.
                if model_name not in cache:
                    model = registry[model_name]()
                    if prepare:
                        model = model.to(TorchModelFactory.device)
                        model.eval()
                    cache[model_name] = model
        return cache[model_name]

    @staticmethod
    def create_feature_extract_model(model_name):
        """Return the cached feature-extraction model registered as *model_name*.

        Raises:
            ValueError: if *model_name* is not in MODELS_FEATURE_EXTRACT.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_FEATURE_EXTRACT,
            TorchModelFactory._feature_extract_models,
            model_name,
        )

    @staticmethod
    def create_detect_model(model_name):
        """Return the cached detection model registered as *model_name*.

        Raises:
            ValueError: if *model_name* is not in MODELS_DETECT.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_DETECT,
            TorchModelFactory._detect_models,
            model_name,
        )

    @staticmethod
    def create_yolo_detect_model():
        """Return the cached ultralytics YOLO detector.

        Unlike create_detect_model(), the newly built model is NOT passed
        through ``.to(device)`` or ``.eval()``. Both methods share the "Yolo"
        slot in ``_detect_models``, so whichever runs first decides which
        variant stays cached.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_DETECT,
            TorchModelFactory._detect_models,
            "Yolo",
            prepare=False,
        )

    @staticmethod
    def create_classication_model(model_name):
        """Return the cached classification model registered as *model_name*.

        The misspelled name is kept for backward compatibility. New code can
        use the create_classification_model alias instead.

        Raises:
            ValueError: if *model_name* is not in MODELS_CLASSIFICATION.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_CLASSIFICATION,
            TorchModelFactory._classification_models,
            model_name,
        )

    # Correctly spelled alias for create_classication_model.
    create_classification_model = create_classication_model

    @staticmethod
    def create_instance_model(model_name):
        """Return the cached instance-segmentation model registered as *model_name*.

        Raises:
            ValueError: if *model_name* is not in MODELS_INSTANCE.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_INSTANCE,
            TorchModelFactory._instance_models,
            model_name,
        )

    @staticmethod
    def create_semantic_model(model_name):
        """Return the cached semantic-segmentation model registered as *model_name*.

        Raises:
            ValueError: if *model_name* is not in MODELS_SEMANTIC.
        """
        return TorchModelFactory._get_or_create(
            TorchModelFactory.MODELS_SEMANTIC,
            TorchModelFactory._semantic_models,
            model_name,
        )