glenn-jocher commited on
Commit
8c326a1
1 Parent(s): 5d4258f

Meshgrid `indexing='ij'` for PyTorch 1.10 (#5309)

Browse files

* Meshgrid `indexing='ij'` for PyTorch 1.10

Will not merge currently as breaks backwards compatibility.

* Meshgrid `indexing='ij'` for PyTorch 1.10

Will not merge currently as breaks backwards compatibility.

* Add check_version hard argument

* Update comment

Files changed (3) hide show
  1. models/yolo.py +5 -2
  2. utils/augmentations.py +1 -1
  3. utils/general.py +7 -4
models/yolo.py CHANGED
@@ -20,7 +20,7 @@ if str(ROOT) not in sys.path:
20
  from models.common import *
21
  from models.experimental import *
22
  from utils.autoanchor import check_anchor_order
23
- from utils.general import check_yaml, make_divisible, print_args, set_logging
24
  from utils.plots import feature_visualization
25
  from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \
26
  select_device, time_sync
@@ -74,7 +74,10 @@ class Detect(nn.Module):
74
 
75
  def _make_grid(self, nx=20, ny=20, i=0):
76
  d = self.anchors[i].device
77
- yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
 
 
 
78
  grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
79
  anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
80
  .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
 
20
  from models.common import *
21
  from models.experimental import *
22
  from utils.autoanchor import check_anchor_order
23
+ from utils.general import check_yaml, make_divisible, print_args, set_logging, check_version
24
  from utils.plots import feature_visualization
25
  from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \
26
  select_device, time_sync
 
74
 
75
  def _make_grid(self, nx=20, ny=20, i=0):
76
  d = self.anchors[i].device
77
+ if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
78
+ yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)], indexing='ij')
79
+ else:
80
+ yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
81
  grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
82
  anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
83
  .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
utils/augmentations.py CHANGED
@@ -20,7 +20,7 @@ class Albumentations:
20
  self.transform = None
21
  try:
22
  import albumentations as A
23
- check_version(A.__version__, '1.0.3') # version requirement
24
 
25
  self.transform = A.Compose([
26
  A.Blur(p=0.01),
 
20
  self.transform = None
21
  try:
22
  import albumentations as A
23
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
24
 
25
  self.transform = A.Compose([
26
  A.Blur(p=0.01),
utils/general.py CHANGED
@@ -220,14 +220,17 @@ def check_git_status():
220
 
221
  def check_python(minimum='3.6.2'):
222
  # Check current python version vs. required python version
223
- check_version(platform.python_version(), minimum, name='Python ')
224
 
225
 
226
- def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
227
  # Check version vs. required version
228
  current, minimum = (pkg.parse_version(x) for x in (current, minimum))
229
- result = (current == minimum) if pinned else (current >= minimum)
230
- assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
 
 
 
231
 
232
 
233
  @try_except
 
220
 
221
  def check_python(minimum='3.6.2'):
222
  # Check current python version vs. required python version
223
+ check_version(platform.python_version(), minimum, name='Python ', hard=True)
224
 
225
 
226
+ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False):
227
  # Check version vs. required version
228
  current, minimum = (pkg.parse_version(x) for x in (current, minimum))
229
+ result = (current == minimum) if pinned else (current >= minimum) # bool
230
+ if hard: # assert min requirements met
231
+ assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
232
+ else:
233
+ return result
234
 
235
 
236
  @try_except