TYH71 commited on
Commit
946ac89
1 Parent(s): 5a0b91b

refactor: yolov5 model loading

Browse files
Files changed (1) hide show
  1. src/model/yolov5.py +12 -7
src/model/yolov5.py CHANGED
@@ -2,18 +2,23 @@
2
  module to load yolov5* model from the ultralytics/yolov5 repo
3
  '''
4
  import torch
 
5
 
6
 
7
- def load_model():
8
  """
9
  It loads the YOLOv5s model from the PyTorch Hub
10
  :return: A model
11
  """
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- _model = torch.hub.load("ultralytics/yolov5", "yolov5m6", device=device)
14
- _model_agnostic = True # NMS class-agnostic
15
- _model.amp = True # enable Automatic Mixed Precision (NMS) for inference
16
- return _model
17
-
 
 
 
 
18
 
19
  model = load_model()
 
2
  module to load yolov5* model from the ultralytics/yolov5 repo
3
  '''
4
  import torch
5
+ from src.core.logger import logger
6
 
7
 
8
+ def load_model(model_repo: str = "ultralytics/yolov5", model_name: str = "yolov5s6"):
9
  """
10
  It loads the YOLOv5s model from the PyTorch Hub
11
  :return: A model
12
  """
13
+ try:
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ _model = torch.hub.load(model_repo, model_name, device=device)
16
+ _model_agnostic = True # NMS class-agnostic
17
+ _model.amp = True # enable Automatic Mixed Precision (NMS) for inference
18
+ return _model
19
+ except Exception as e:
20
+ logger.debug("Exception Caught: {}".format(e))
21
+ finally:
22
+ logger.info(f"[{model_repo}] {model_name} loaded with AMP [Device: {device}]")
23
 
24
  model = load_model()