Spaces:
Runtime error
Runtime error
File size: 3,189 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
# Registry mapping a device type to its DataParallel wrapper class.
# ``build_dp`` adds the 'npu' and 'mlu' entries lazily on first use.
dp_factory = dict(cuda=MMDataParallel, cpu=MMDataParallel)

# Registry mapping a device type to its DistributedDataParallel wrapper class.
# ``build_ddp`` adds the 'npu' and 'mlu' entries lazily on first use.
ddp_factory = dict(cuda=MMDistributedDataParallel)
def build_dp(model, device='cuda', dim=0, *args, **kwargs):
    """Build a DataParallel module by device type.

    The model is moved to the requested device and then wrapped with the
    class registered in ``dp_factory`` for that device: ``MMDataParallel``
    for cuda and cpu, ``NPUDataParallel`` for npu and ``MLUDataParallel``
    for mlu. The NPU and MLU wrapper classes are imported on demand and
    registered in ``dp_factory`` the first time they are requested.

    Args:
        model (:class:`nn.Module`): model to be parallelized.
        device (str): device type, cuda, cpu, npu or mlu. Defaults to cuda.
        dim (int): Dimension used to scatter the data. Defaults to 0.
        *args: extra positional arguments forwarded to the wrapper class.
        **kwargs: extra keyword arguments forwarded to the wrapper class.
            If ``device_ids`` is present and non-empty, its first entry
            selects the cuda/npu device the model is placed on; otherwise
            the current device is used.

    Returns:
        nn.Module: the parallelized model.

    Raises:
        KeyError: if ``device`` has no entry in ``dp_factory``.
    """
    # ``device_ids`` is optional for the wrapper classes, so do not index
    # into it blindly: a missing or empty value falls back to the current
    # device instead of raising KeyError/IndexError.
    device_ids = kwargs.get('device_ids')
    primary_id = device_ids[0] if device_ids else None

    if device == 'npu':
        from mmcv.device.npu import NPUDataParallel
        dp_factory['npu'] = NPUDataParallel
        if primary_id is not None:
            torch.npu.set_device(primary_id)
        torch.npu.set_compile_mode(jit_compile=False)
        model = model.npu()
    elif device == 'cuda':
        # ``Module.cuda(None)`` moves the model to the current CUDA device.
        model = model.cuda(primary_id)
    elif device == 'mlu':
        from mmcv.device.mlu import MLUDataParallel
        dp_factory['mlu'] = MLUDataParallel
        model = model.mlu()
    # For 'cpu' the model is wrapped as-is, without moving it.

    return dp_factory[device](model, *args, dim=dim, **kwargs)
def build_ddp(model, device='cuda', *args, **kwargs):
    """Build a DistributedDataParallel module by device type.

    The model is moved to the requested device and then wrapped with the
    class registered in ``ddp_factory`` for that device:
    ``MMDistributedDataParallel`` for cuda, ``MLUDistributedDataParallel``
    for mlu and ``NPUDistributedDataParallel`` for npu. The MLU and NPU
    wrapper classes are imported on demand and registered in
    ``ddp_factory`` the first time they are requested.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, cuda, mlu or npu. Defaults to cuda.
        *args: extra positional arguments forwarded to the wrapper class.
        **kwargs: extra keyword arguments forwarded to the wrapper class.

    Returns:
        :class:`nn.Module`: the parallelized module.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
    """
    supported_devices = ['cuda', 'mlu', 'npu']
    assert device in supported_devices, \
        'Only available for cuda or mlu or npu devices.'

    if device == 'cuda':
        model = model.cuda()
    elif device == 'mlu':
        from mmcv.device.mlu import MLUDistributedDataParallel
        ddp_factory['mlu'] = MLUDistributedDataParallel
        model = model.mlu()
    elif device == 'npu':
        from mmcv.device.npu import NPUDistributedDataParallel
        torch.npu.set_compile_mode(jit_compile=False)
        ddp_factory['npu'] = NPUDistributedDataParallel
        model = model.npu()

    wrapper_cls = ddp_factory[device]
    return wrapper_cls(model, *args, **kwargs)
def is_npu_available():
    """Return whether torch exposes an ``npu`` backend that is available."""
    # Stock torch builds have no ``torch.npu`` attribute at all.
    if not hasattr(torch, 'npu'):
        return False
    return torch.npu.is_available()
def is_mlu_available():
    """Return whether torch exposes an MLU probe that reports available."""
    # Stock torch builds have no ``torch.is_mlu_available`` function.
    if not hasattr(torch, 'is_mlu_available'):
        return False
    return torch.is_mlu_available()
def get_device():
    """Return the name of an available device: npu, cuda, mlu or cpu.

    All three accelerator probes are evaluated; the first available one in
    the order npu, cuda, mlu is returned, and 'cpu' is the fallback when
    none of them is available.
    """
    probes = (
        ('npu', is_npu_available()),
        ('cuda', torch.cuda.is_available()),
        ('mlu', is_mlu_available()),
    )
    for device_name, available in probes:
        if available:
            return device_name
    return 'cpu'
|