glenn-jocher
commited on
Commit
β’
c215878
1
Parent(s):
27911dc
YOLOv5 Apple Metal Performance Shader (MPS) support (#7878)
Browse files* Apple Metal Performance Shader (MPS) device support
Following https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/
Should work with Apple M1 devices with PyTorch nightly installed with command `--device mps`. Usage examples:
```bash
python train.py --device mps
python detect.py --device mps
python val.py --device mps
```
* Update device strategy to fix MPS issue
- export.py +1 -1
- models/common.py +1 -1
- models/experimental.py +3 -3
- models/tf.py +1 -1
- utils/torch_utils.py +5 -2
export.py
CHANGED
@@ -486,7 +486,7 @@ def run(
|
|
486 |
if half:
|
487 |
assert device.type != 'cpu' or coreml or xml, '--half only compatible with GPU export, i.e. use --device 0'
|
488 |
assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
|
489 |
-
model = attempt_load(weights,
|
490 |
nc, names = model.nc, model.names # number of classes, class names
|
491 |
|
492 |
# Checks
|
|
|
486 |
if half:
|
487 |
assert device.type != 'cpu' or coreml or xml, '--half only compatible with GPU export, i.e. use --device 0'
|
488 |
assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
|
489 |
+
model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
|
490 |
nc, names = model.nc, model.names # number of classes, class names
|
491 |
|
492 |
# Checks
|
models/common.py
CHANGED
@@ -331,7 +331,7 @@ class DetectMultiBackend(nn.Module):
|
|
331 |
names = yaml.safe_load(f)['names']
|
332 |
|
333 |
if pt: # PyTorch
|
334 |
-
model = attempt_load(weights if isinstance(weights, list) else w,
|
335 |
stride = max(int(model.stride.max()), 32) # model stride
|
336 |
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
337 |
model.half() if fp16 else model.float()
|
|
|
331 |
names = yaml.safe_load(f)['names']
|
332 |
|
333 |
if pt: # PyTorch
|
334 |
+
model = attempt_load(weights if isinstance(weights, list) else w, device=device)
|
335 |
stride = max(int(model.stride.max()), 32) # model stride
|
336 |
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
337 |
model.half() if fp16 else model.float()
|
models/experimental.py
CHANGED
@@ -71,14 +71,14 @@ class Ensemble(nn.ModuleList):
|
|
71 |
return y, None # inference, train output
|
72 |
|
73 |
|
74 |
-
def attempt_load(weights,
|
75 |
from models.yolo import Detect, Model
|
76 |
|
77 |
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
78 |
model = Ensemble()
|
79 |
for w in weights if isinstance(weights, list) else [weights]:
|
80 |
-
ckpt = torch.load(attempt_download(w)
|
81 |
-
ckpt = (ckpt.get('ema') or ckpt['model']).float() # FP32 model
|
82 |
model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode
|
83 |
|
84 |
# Compatibility updates
|
|
|
71 |
return y, None # inference, train output
|
72 |
|
73 |
|
74 |
+
def attempt_load(weights, device=None, inplace=True, fuse=True):
|
75 |
from models.yolo import Detect, Model
|
76 |
|
77 |
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
78 |
model = Ensemble()
|
79 |
for w in weights if isinstance(weights, list) else [weights]:
|
80 |
+
ckpt = torch.load(attempt_download(w))
|
81 |
+
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
82 |
model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode
|
83 |
|
84 |
# Compatibility updates
|
models/tf.py
CHANGED
@@ -536,7 +536,7 @@ def run(
|
|
536 |
):
|
537 |
# PyTorch model
|
538 |
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
|
539 |
-
model = attempt_load(weights,
|
540 |
_ = model(im) # inference
|
541 |
model.info()
|
542 |
|
|
|
536 |
):
|
537 |
# PyTorch model
|
538 |
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
|
539 |
+
model = attempt_load(weights, device=torch.device('cpu'), inplace=True, fuse=False)
|
540 |
_ = model(im) # inference
|
541 |
model.info()
|
542 |
|
utils/torch_utils.py
CHANGED
@@ -54,7 +54,8 @@ def select_device(device='', batch_size=0, newline=True):
|
|
54 |
s = f'YOLOv5 π {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
|
55 |
device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
|
56 |
cpu = device == 'cpu'
|
57 |
-
|
|
|
58 |
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
59 |
elif device: # non-cpu device requested
|
60 |
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
@@ -71,13 +72,15 @@ def select_device(device='', batch_size=0, newline=True):
|
|
71 |
for i, d in enumerate(devices):
|
72 |
p = torch.cuda.get_device_properties(i)
|
73 |
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
|
|
|
|
|
74 |
else:
|
75 |
s += 'CPU\n'
|
76 |
|
77 |
if not newline:
|
78 |
s = s.rstrip()
|
79 |
LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
|
80 |
-
return torch.device('cuda:0' if cuda else 'cpu')
|
81 |
|
82 |
|
83 |
def time_sync():
|
|
|
54 |
s = f'YOLOv5 π {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
|
55 |
device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
|
56 |
cpu = device == 'cpu'
|
57 |
+
mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
|
58 |
+
if cpu or mps:
|
59 |
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
60 |
elif device: # non-cpu device requested
|
61 |
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
|
|
72 |
for i, d in enumerate(devices):
|
73 |
p = torch.cuda.get_device_properties(i)
|
74 |
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
|
75 |
+
elif mps:
|
76 |
+
s += 'MPS\n'
|
77 |
else:
|
78 |
s += 'CPU\n'
|
79 |
|
80 |
if not newline:
|
81 |
s = s.rstrip()
|
82 |
LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
|
83 |
+
return torch.device('cuda:0' if cuda else 'mps' if mps else 'cpu')
|
84 |
|
85 |
|
86 |
def time_sync():
|