onipot commited on
Commit
976767f
1 Parent(s): 05edfc7

fix upsampling err with torch 1.11+

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -2
  2. yolov5/models/experimental.py +10 -8
requirements.txt CHANGED
@@ -5,8 +5,8 @@ opencv-python-headless
5
  Pillow
6
  PyYAML>=5.3.1
7
  scipy>=1.4.1
8
- torch==1.10.1
9
- torchvision==0.11.2
10
  tqdm>=4.41.0
11
 
12
  # logging -------------------------------------
 
5
  Pillow
6
  PyYAML>=5.3.1
7
  scipy>=1.4.1
8
+ torch>=1.7.0
9
+ torchvision>=0.8.1
10
  tqdm>=4.41.0
11
 
12
  # logging -------------------------------------
yolov5/models/experimental.py CHANGED
@@ -94,21 +94,23 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
94
  model = Ensemble()
95
  for w in weights if isinstance(weights, list) else [weights]:
96
  ckpt = torch.load(attempt_download(w), map_location=map_location) # load
97
- if fuse:
98
- model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
99
- else:
100
- model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse
101
-
102
  # Compatibility updates
103
  for m in model.modules():
104
- if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
 
105
  m.inplace = inplace # pytorch 1.7.0 compatibility
106
- if type(m) is Detect:
107
  if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
108
  delattr(m, 'anchor_grid')
109
  setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
110
- elif type(m) is Conv:
111
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
 
 
 
 
112
 
113
  if len(model) == 1:
114
  return model[-1] # return model
 
94
  model = Ensemble()
95
  for w in weights if isinstance(weights, list) else [weights]:
96
  ckpt = torch.load(attempt_download(w), map_location=map_location) # load
97
+ ckpt = (ckpt['ema'] or ckpt['model']).float() # FP32 model
98
+ model.append(ckpt.fuse().eval() if fuse else ckpt.eval())
 
 
 
99
  # Compatibility updates
100
  for m in model.modules():
101
+ t = type(m)
102
+ if t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
103
  m.inplace = inplace # pytorch 1.7.0 compatibility
104
+ if t is Detect:
105
  if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
106
  delattr(m, 'anchor_grid')
107
  setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
108
+ elif t is Conv:
109
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
110
+ elif t is nn.Upsample:
111
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
112
+ elif t is Conv:
113
+ m._non_persistent_buffers_set = set() # torch 1.6.0 compatibility
114
 
115
  if len(model) == 1:
116
  return model[-1] # return model