Spaces:
Build error
Build error
refactor: yolov5 model loading
Browse files- src/model/yolov5.py +12 -7
src/model/yolov5.py
CHANGED
@@ -2,18 +2,23 @@
|
|
2 |
module to load yolov5* model from the ultralytics/yolov5 repo
|
3 |
'''
|
4 |
import torch
|
|
|
5 |
|
6 |
|
7 |
-
def load_model():
|
8 |
"""
|
9 |
It loads the YOLOv5s model from the PyTorch Hub
|
10 |
:return: A model
|
11 |
"""
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
18 |
|
19 |
model = load_model()
|
|
|
2 |
module to load yolov5* model from the ultralytics/yolov5 repo
|
3 |
'''
|
4 |
import torch
|
5 |
+
from src.core.logger import logger
|
6 |
|
7 |
|
8 |
+
def load_model(model_repo: str = "ultralytics/yolov5", model_name: str = "yolov5s6"):
|
9 |
"""
|
10 |
It loads the YOLOv5s model from the PyTorch Hub
|
11 |
:return: A model
|
12 |
"""
|
13 |
+
try:
|
14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
+
_model = torch.hub.load(model_repo, model_name, device=device)
|
16 |
+
_model_agnostic = True # NMS class-agnostic
|
17 |
+
_model.amp = True # enable Automatic Mixed Precision (NMS) for inference
|
18 |
+
return _model
|
19 |
+
except Exception as e:
|
20 |
+
logger.debug("Exception Caught: {}".format(e))
|
21 |
+
finally:
|
22 |
+
logger.info(f"[{model_repo}] {model_name} loaded with AMP [Device: {device}]")
|
23 |
|
24 |
model = load_model()
|