cgerum glenn-jocher commited on
Commit
b133baa
1 Parent(s): 9ab561d

Add `device` argument to PyTorch Hub models (#3104)

Browse files

* Allow to manual selection of device for torchhub models

* single line device

nested torch.device(torch.device(device)) ok

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (1) hide show
  1. hubconf.py +21 -20
hubconf.py CHANGED
@@ -8,7 +8,7 @@ Usage:
8
  import torch
9
 
10
 
11
- def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
12
  """Creates a specified YOLOv5 model
13
 
14
  Arguments:
@@ -18,6 +18,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
18
  classes (int): number of model classes
19
  autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
20
  verbose (bool): print all information to screen
 
21
 
22
  Returns:
23
  YOLOv5 pytorch model
@@ -50,7 +51,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
50
  model.names = ckpt['model'].names # set class names attribute
51
  if autoshape:
52
  model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
53
- device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
54
  return model.to(device)
55
 
56
  except Exception as e:
@@ -59,49 +60,49 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
59
  raise Exception(s) from e
60
 
61
 
62
- def custom(path='path/to/model.pt', autoshape=True, verbose=True):
63
  # YOLOv5 custom or local model
64
- return _create(path, autoshape=autoshape, verbose=verbose)
65
 
66
 
67
- def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
68
  # YOLOv5-small model https://github.com/ultralytics/yolov5
69
- return _create('yolov5s', pretrained, channels, classes, autoshape, verbose)
70
 
71
 
72
- def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
73
  # YOLOv5-medium model https://github.com/ultralytics/yolov5
74
- return _create('yolov5m', pretrained, channels, classes, autoshape, verbose)
75
 
76
 
77
- def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
78
  # YOLOv5-large model https://github.com/ultralytics/yolov5
79
- return _create('yolov5l', pretrained, channels, classes, autoshape, verbose)
80
 
81
 
82
- def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
83
  # YOLOv5-xlarge model https://github.com/ultralytics/yolov5
84
- return _create('yolov5x', pretrained, channels, classes, autoshape, verbose)
85
 
86
 
87
- def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
88
  # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
89
- return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose)
90
 
91
 
92
- def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
93
  # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
94
- return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose)
95
 
96
 
97
- def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
98
  # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
99
- return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose)
100
 
101
 
102
- def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
103
  # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
104
- return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose)
105
 
106
 
107
  if __name__ == '__main__':
 
8
  import torch
9
 
10
 
11
+ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
12
  """Creates a specified YOLOv5 model
13
 
14
  Arguments:
 
18
  classes (int): number of model classes
19
  autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
20
  verbose (bool): print all information to screen
21
+ device (str, torch.device, None): device to use for model parameters
22
 
23
  Returns:
24
  YOLOv5 pytorch model
 
51
  model.names = ckpt['model'].names # set class names attribute
52
  if autoshape:
53
  model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
54
+ device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
55
  return model.to(device)
56
 
57
  except Exception as e:
 
60
  raise Exception(s) from e
61
 
62
 
63
+ def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
64
  # YOLOv5 custom or local model
65
+ return _create(path, autoshape=autoshape, verbose=verbose, device=device)
66
 
67
 
68
+ def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
69
  # YOLOv5-small model https://github.com/ultralytics/yolov5
70
+ return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device)
71
 
72
 
73
+ def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
74
  # YOLOv5-medium model https://github.com/ultralytics/yolov5
75
+ return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device)
76
 
77
 
78
+ def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
79
  # YOLOv5-large model https://github.com/ultralytics/yolov5
80
+ return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device)
81
 
82
 
83
+ def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
84
  # YOLOv5-xlarge model https://github.com/ultralytics/yolov5
85
+ return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device)
86
 
87
 
88
+ def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
89
  # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
90
+ return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device)
91
 
92
 
93
+ def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
94
  # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
95
+ return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device)
96
 
97
 
98
+ def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
99
  # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
100
+ return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device)
101
 
102
 
103
+ def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
104
  # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
105
+ return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device)
106
 
107
 
108
  if __name__ == '__main__':