glenn-jocher committed on
Commit c215878 • 1 Parent(s): 27911dc

YOLOv5 Apple Metal Performance Shader (MPS) support (#7878)


* Apple Metal Performance Shader (MPS) device support

Following https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/

Should work on Apple M1 devices with the PyTorch nightly build installed; select the device with `--device mps`. Usage examples:
```bash
python train.py --device mps
python detect.py --device mps
python val.py --device mps
```

* Update device strategy to fix MPS issue
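
As a quick sanity check before passing `--device mps`, the snippet below confirms that the installed PyTorch build actually exposes the MPS backend. This is a minimal sketch, not part of the commit; it assumes a build recent enough to provide `torch.backends.mps` (per the PyTorch blog post linked above).

```python
import torch

# Sketch: verify the MPS backend is present before requesting --device mps
if torch.backends.mps.is_available():
    x = torch.ones(1, device='mps')   # allocate a tensor on the Apple GPU
    print('MPS available:', x.device)
else:
    print('MPS not available, falling back to CPU')
```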

export.py CHANGED
```diff
@@ -486,7 +486,7 @@ def run(
     if half:
         assert device.type != 'cpu' or coreml or xml, '--half only compatible with GPU export, i.e. use --device 0'
         assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
-    model = attempt_load(weights, map_location=device, inplace=True, fuse=True)  # load FP32 model
+    model = attempt_load(weights, device=device, inplace=True, fuse=True)  # load FP32 model
     nc, names = model.nc, model.names  # number of classes, class names
 
     # Checks
```
models/common.py CHANGED
```diff
@@ -331,7 +331,7 @@ class DetectMultiBackend(nn.Module):
                 names = yaml.safe_load(f)['names']
 
         if pt:  # PyTorch
-            model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
+            model = attempt_load(weights if isinstance(weights, list) else w, device=device)
             stride = max(int(model.stride.max()), 32)  # model stride
             names = model.module.names if hasattr(model, 'module') else model.names  # get class names
             model.half() if fp16 else model.float()
```
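
For context, `detect.py` and `val.py` reach this PyTorch branch through `DetectMultiBackend`. A minimal usage sketch, assuming the constructor signature at this commit (the weights filename is illustrative):

```python
import torch
from models.common import DetectMultiBackend

device = torch.device('mps')  # e.g. as returned by select_device('mps')
model = DetectMultiBackend('yolov5s.pt', device=device)  # pt branch above calls attempt_load(..., device=device)
```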
models/experimental.py CHANGED
```diff
@@ -71,14 +71,14 @@ class Ensemble(nn.ModuleList):
         return y, None  # inference, train output
 
 
-def attempt_load(weights, map_location=None, inplace=True, fuse=True):
+def attempt_load(weights, device=None, inplace=True, fuse=True):
     from models.yolo import Detect, Model
 
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
-        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
-        ckpt = (ckpt.get('ema') or ckpt['model']).float()  # FP32 model
+        ckpt = torch.load(attempt_download(w))
+        ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
         model.append(ckpt.fuse().eval() if fuse else ckpt.eval())  # fused or un-fused model in eval mode
 
     # Compatibility updates
```
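
With the new signature, callers pass a `torch.device` instead of a `map_location`; the checkpoint is loaded first and then moved with `.to(device)`. A short usage sketch of the updated loader (weights filename illustrative, input shape arbitrary):

```python
import torch
from models.experimental import attempt_load

device = torch.device('mps')  # or 'cuda:0' / 'cpu'
model = attempt_load('yolov5s.pt', device=device, inplace=True, fuse=True)  # fused FP32 model on the MPS device
im = torch.zeros(1, 3, 640, 640, device=device)  # dummy BCHW input
pred = model(im)  # inference
```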
models/tf.py CHANGED
```diff
@@ -536,7 +536,7 @@ def run(
 ):
     # PyTorch model
     im = torch.zeros((batch_size, 3, *imgsz))  # BCHW image
-    model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
+    model = attempt_load(weights, device=torch.device('cpu'), inplace=True, fuse=False)
     _ = model(im)  # inference
     model.info()
 
```
utils/torch_utils.py CHANGED
```diff
@@ -54,7 +54,8 @@ def select_device(device='', batch_size=0, newline=True):
     s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
     device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
     cpu = device == 'cpu'
-    if cpu:
+    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
+    if cpu or mps:
         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
     elif device:  # non-cpu device requested
         os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
@@ -71,13 +72,15 @@ def select_device(device='', batch_size=0, newline=True):
         for i, d in enumerate(devices):
             p = torch.cuda.get_device_properties(i)
             s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
+    elif mps:
+        s += 'MPS\n'
     else:
         s += 'CPU\n'
 
     if not newline:
         s = s.rstrip()
     LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
-    return torch.device('cuda:0' if cuda else 'cpu')
+    return torch.device('cuda:0' if cuda else 'mps' if mps else 'cpu')
 
 
 def time_sync():
```
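
The updated `select_device` maps the `--device` flag to the `torch.device` consumed by the loaders above. A short usage sketch (behaviour as implied by the diff; log output abbreviated):

```python
from utils.torch_utils import select_device

device = select_device('mps')  # logs 'YOLOv5 ... MPS' and returns torch.device('mps')
device = select_device('')     # picks 'cuda:0' if a CUDA GPU is visible, otherwise 'cpu'
device = select_device('cpu')  # forces CPU by hiding CUDA devices
```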