Nanobit glenn-jocher committed on
Commit
14b0abe
1 Parent(s): c0ffcdf

autoShape() default for PyTorch Hub models (#1692)

Browse files

* Add autoshape parameter

* Remove autoshape call in README

* Update hubconf.py

* file/URI inputs and autoshape check passthrough

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (3) hide show
  1. README.md +1 -1
  2. hubconf.py +14 -14
  3. models/common.py +26 -18
README.md CHANGED
@@ -106,7 +106,7 @@ import torch
106
  from PIL import Image
107
 
108
  # Model
109
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).autoshape() # for PIL/cv2/np inputs and NMS
110
 
111
  # Images
112
  img1 = Image.open('zidane.jpg')
 
106
  from PIL import Image
107
 
108
  # Model
109
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # for PIL/cv2/np inputs and NMS
110
 
111
  # Images
112
  img1 = Image.open('zidane.jpg')
hubconf.py CHANGED
@@ -17,7 +17,7 @@ dependencies = ['torch', 'yaml']
17
  set_logging()
18
 
19
 
20
- def create(name, pretrained, channels, classes):
21
  """Creates a specified YOLOv5 model
22
 
23
  Arguments:
@@ -41,7 +41,8 @@ def create(name, pretrained, channels, classes):
41
  model.load_state_dict(state_dict, strict=False) # load
42
  if len(ckpt['model'].names) == classes:
43
  model.names = ckpt['model'].names # set class names attribute
44
- # model = model.autoshape() # for PIL/cv2/np inputs and NMS
 
45
  return model
46
 
47
  except Exception as e:
@@ -50,7 +51,7 @@ def create(name, pretrained, channels, classes):
50
  raise Exception(s) from e
51
 
52
 
53
- def yolov5s(pretrained=False, channels=3, classes=80):
54
  """YOLOv5-small model from https://github.com/ultralytics/yolov5
55
 
56
  Arguments:
@@ -61,10 +62,10 @@ def yolov5s(pretrained=False, channels=3, classes=80):
61
  Returns:
62
  pytorch model
63
  """
64
- return create('yolov5s', pretrained, channels, classes)
65
 
66
 
67
- def yolov5m(pretrained=False, channels=3, classes=80):
68
  """YOLOv5-medium model from https://github.com/ultralytics/yolov5
69
 
70
  Arguments:
@@ -75,10 +76,10 @@ def yolov5m(pretrained=False, channels=3, classes=80):
75
  Returns:
76
  pytorch model
77
  """
78
- return create('yolov5m', pretrained, channels, classes)
79
 
80
 
81
- def yolov5l(pretrained=False, channels=3, classes=80):
82
  """YOLOv5-large model from https://github.com/ultralytics/yolov5
83
 
84
  Arguments:
@@ -89,10 +90,10 @@ def yolov5l(pretrained=False, channels=3, classes=80):
89
  Returns:
90
  pytorch model
91
  """
92
- return create('yolov5l', pretrained, channels, classes)
93
 
94
 
95
- def yolov5x(pretrained=False, channels=3, classes=80):
96
  """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
97
 
98
  Arguments:
@@ -103,10 +104,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):
103
  Returns:
104
  pytorch model
105
  """
106
- return create('yolov5x', pretrained, channels, classes)
107
 
108
 
109
- def custom(path_or_model='path/to/model.pt'):
110
  """YOLOv5-custom model from https://github.com/ultralytics/yolov5
111
 
112
  Arguments (3 options):
@@ -124,13 +125,12 @@ def custom(path_or_model='path/to/model.pt'):
124
  hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
125
  hub_model.load_state_dict(model.float().state_dict()) # load state_dict
126
  hub_model.names = model.names # class names
127
- return hub_model
128
 
129
 
130
  if __name__ == '__main__':
131
- model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # pretrained example
132
  # model = custom(path_or_model='path/to/model.pt') # custom example
133
- model = model.autoshape() # for PIL/cv2/np inputs and NMS
134
 
135
  # Verify inference
136
  from PIL import Image
 
17
  set_logging()
18
 
19
 
20
+ def create(name, pretrained, channels, classes, autoshape):
21
  """Creates a specified YOLOv5 model
22
 
23
  Arguments:
 
41
  model.load_state_dict(state_dict, strict=False) # load
42
  if len(ckpt['model'].names) == classes:
43
  model.names = ckpt['model'].names # set class names attribute
44
+ if autoshape:
45
+ model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
46
  return model
47
 
48
  except Exception as e:
 
51
  raise Exception(s) from e
52
 
53
 
54
+ def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True):
55
  """YOLOv5-small model from https://github.com/ultralytics/yolov5
56
 
57
  Arguments:
 
62
  Returns:
63
  pytorch model
64
  """
65
+ return create('yolov5s', pretrained, channels, classes, autoshape)
66
 
67
 
68
+ def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True):
69
  """YOLOv5-medium model from https://github.com/ultralytics/yolov5
70
 
71
  Arguments:
 
76
  Returns:
77
  pytorch model
78
  """
79
+ return create('yolov5m', pretrained, channels, classes, autoshape)
80
 
81
 
82
+ def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True):
83
  """YOLOv5-large model from https://github.com/ultralytics/yolov5
84
 
85
  Arguments:
 
90
  Returns:
91
  pytorch model
92
  """
93
+ return create('yolov5l', pretrained, channels, classes, autoshape)
94
 
95
 
96
+ def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True):
97
  """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
98
 
99
  Arguments:
 
104
  Returns:
105
  pytorch model
106
  """
107
+ return create('yolov5x', pretrained, channels, classes, autoshape)
108
 
109
 
110
+ def custom(path_or_model='path/to/model.pt', autoshape=True):
111
  """YOLOv5-custom model from https://github.com/ultralytics/yolov5
112
 
113
  Arguments (3 options):
 
125
  hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
126
  hub_model.load_state_dict(model.float().state_dict()) # load state_dict
127
  hub_model.names = model.names # class names
128
+ return hub_model.autoshape() if autoshape else hub_model
129
 
130
 
131
  if __name__ == '__main__':
132
+ model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example
133
  # model = custom(path_or_model='path/to/model.pt') # custom example
 
134
 
135
  # Verify inference
136
  from PIL import Image
models/common.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import math
4
  import numpy as np
 
5
  import torch
6
  import torch.nn as nn
7
  from PIL import Image, ImageDraw
@@ -143,35 +144,42 @@ class autoShape(nn.Module):
143
  super(autoShape, self).__init__()
144
  self.model = model.eval()
145
 
 
 
 
 
146
  def forward(self, imgs, size=640, augment=False, profile=False):
147
- # supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
148
- # opencv: imgs = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
149
- # PIL: imgs = Image.open('image.jpg') # HWC x(720,1280,3)
150
- # numpy: imgs = np.zeros((720,1280,3)) # HWC
151
- # torch: imgs = torch.zeros(16,3,720,1280) # BCHW
152
- # multiple: imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
 
 
153
 
154
  p = next(self.model.parameters()) # for device and type
155
  if isinstance(imgs, torch.Tensor): # torch
156
  return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
157
 
158
  # Pre-process
159
- if not isinstance(imgs, list):
160
- imgs = [imgs]
161
  shape0, shape1 = [], [] # image and inference shapes
162
- batch = range(len(imgs)) # batch size
163
- for i in batch:
164
- imgs[i] = np.array(imgs[i]) # to numpy
165
- if imgs[i].shape[0] < 5: # image in CHW
166
- imgs[i] = imgs[i].transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
167
- imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input
168
- s = imgs[i].shape[:2] # HWC
 
169
  shape0.append(s) # image shape
170
  g = (size / max(s)) # gain
171
  shape1.append([y * g for y in s])
 
172
  shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
173
- x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
174
- x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
175
  x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
176
  x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
177
 
@@ -181,7 +189,7 @@ class autoShape(nn.Module):
181
  y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
182
 
183
  # Post-process
184
- for i in batch:
185
  scale_coords(shape1, y[i][:, :4], shape0[i])
186
 
187
  return Detections(imgs, y, self.names)
 
2
 
3
  import math
4
  import numpy as np
5
+ import requests
6
  import torch
7
  import torch.nn as nn
8
  from PIL import Image, ImageDraw
 
144
  super(autoShape, self).__init__()
145
  self.model = model.eval()
146
 
147
+ def autoshape(self):
148
+ print('autoShape already enabled, skipping... ') # model already converted to model.autoshape()
149
+ return self
150
+
151
  def forward(self, imgs, size=640, augment=False, profile=False):
152
+ # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
153
+ # filename: imgs = 'data/samples/zidane.jpg'
154
+ # URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
155
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
156
+ # PIL: = Image.open('image.jpg') # HWC x(720,1280,3)
157
+ # numpy: = np.zeros((720,1280,3)) # HWC
158
+ # torch: = torch.zeros(16,3,720,1280) # BCHW
159
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
160
 
161
  p = next(self.model.parameters()) # for device and type
162
  if isinstance(imgs, torch.Tensor): # torch
163
  return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
164
 
165
  # Pre-process
166
+ n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
 
167
  shape0, shape1 = [], [] # image and inference shapes
168
+ for i, im in enumerate(imgs):
169
+ if isinstance(im, str): # filename or uri
170
+ im = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im) # open
171
+ im = np.array(im) # to numpy
172
+ if im.shape[0] < 5: # image in CHW
173
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
174
+ im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
175
+ s = im.shape[:2] # HWC
176
  shape0.append(s) # image shape
177
  g = (size / max(s)) # gain
178
  shape1.append([y * g for y in s])
179
+ imgs[i] = im # update
180
  shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
181
+ x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
182
+ x = np.stack(x, 0) if n > 1 else x[0][None] # stack
183
  x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
184
  x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
185
 
 
189
  y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
190
 
191
  # Post-process
192
+ for i in range(n):
193
  scale_coords(shape1, y[i][:, :4], shape0[i])
194
 
195
  return Detections(imgs, y, self.names)