File size: 930 Bytes
522c7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
6043556
 
522c7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6043556
522c7b8
6043556
522c7b8
6043556
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# models_loader.py
import torch
from ultralytics import YOLO

try:
    from ultralytics import RTDETR
except ImportError:
    RTDETR = None

try:
    from ultralytics import YOLOWorld
except ImportError:
    YOLOWorld = None

# Run on the GPU whenever PyTorch reports a usable CUDA device, else on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Maps the model name accepted by the loader to the weights file handed to the
# ultralytics model class. Insertion order is kept as declared.
# NOTE(review): confirm that every weights file listed here (in particular the
# YOLOv5/v6/v7 entries) can be resolved by the installed ultralytics release.
EXTENDED_MODELS = {
    # YOLO family
    "YOLOv11": "yolo11n.pt",
    "YOLOv10": "yolov10n.pt",
    "YOLOv9": "yolov9c.pt",
    "YOLOv8": "yolov8n.pt",
    "YOLOv7": "yolov7.pt",
    "YOLOv6": "yolov6n.pt",
    "YOLOv5": "yolov5s.pt",
    # Models with their own (optionally imported) ultralytics classes
    "RT-DETR-l": "rtdetr-l.pt",
    "YOLOv8s-Worldv2": "yolov8s-worldv2.pt",
}

def load_model(model_choice):
    """Build the model registered under *model_choice* and move it to ``device``.

    Args:
        model_choice: A key of ``EXTENDED_MODELS`` (for example ``"YOLOv8"``).

    Returns:
        The result of ``.to(device)`` on the constructed ``RTDETR``,
        ``YOLOWorld`` or ``YOLO`` object.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``EXTENDED_MODELS``.
            The message lists the registered names.
    """
    try:
        weights = EXTENDED_MODELS[model_choice]
    except KeyError:
        # Same exception type as a plain dict lookup, but with an actionable
        # message instead of just the offending key.
        raise KeyError(
            f"Unknown model choice {model_choice!r}; "
            f"expected one of: {', '.join(EXTENDED_MODELS)}"
        ) from None

    # RTDETR and YOLOWorld are optional imports that are set to None when the
    # installed ultralytics does not provide them. In that case the generic
    # YOLO class is used as a best-effort fallback.
    if model_choice == "RT-DETR-l" and RTDETR is not None:
        model_cls = RTDETR
    elif model_choice == "YOLOv8s-Worldv2" and YOLOWorld is not None:
        model_cls = YOLOWorld
    else:
        model_cls = YOLO

    return model_cls(weights).to(device)