Spaces:
Paused
Paused
File size: 6,018 Bytes
9d3cb0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import os
import typing
import torch
import torch.distributed as dist
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel
from ..data.datasets import ResumableDistributedSampler as DistributedSampler
from ..data.datasets import ResumableSequentialSampler as SequentialSampler
class Accelerator:  # pragma: no cover
    """This class is used to prepare models and dataloaders for
    usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
    prepare the respective objects. In the case of models, they are moved to
    the appropriate device and, under DDP, SyncBatchNorm is applied to them.
    In the case of dataloaders, a sampler is created and the dataloader is
    initialized with that sampler.

    If the world size (number of visible GPUs) is at most 1, no DDP/DP
    wrapping happens: prepare_model only moves the model to the device, and
    prepare_dataloader builds a DataLoader around a sequential sampler. If the
    environment variable ``LOCAL_RANK`` is not set, then the script was
    launched without ``torchrun``, and ``DataParallel`` will be used instead
    of ``DistributedDataParallel`` (not recommended), if the world size
    (number of GPUs) is greater than 1.

    Parameters
    ----------
    amp : bool, optional
        Whether or not to enable automatic mixed precision, by default False
    """

    def __init__(self, amp: bool = False):
        local_rank = os.getenv("LOCAL_RANK", None)
        # Parse once so ``self.local_rank`` is always an int. Previously the
        # conversion only happened on the DDP path, so launching with torchrun
        # on a single visible GPU left ``local_rank`` as the string "0".
        local_rank = None if local_rank is None else int(local_rank)

        self.world_size = torch.cuda.device_count()
        self.use_ddp = self.world_size > 1 and local_rank is not None
        self.use_dp = self.world_size > 1 and local_rank is None
        self.device = "cpu" if self.world_size == 0 else "cuda"

        if self.use_ddp:
            # NOTE(review): world_size is the *local* GPU count and rank is the
            # *local* rank, which presumes a single node running one process
            # per visible GPU — confirm before using this for multi-node runs.
            dist.init_process_group(
                "nccl",
                init_method="env://",
                world_size=self.world_size,
                rank=local_rank,
            )

        self.local_rank = 0 if local_rank is None else local_rank
        self.amp = amp

        class DummyScaler:
            """No-op stand-in for ``GradScaler`` used when ``amp`` is off."""

            def __init__(self):
                pass

            def step(self, optimizer):
                optimizer.step()

            def scale(self, loss):
                return loss

            def unscale_(self, optimizer):
                return optimizer

            def update(self):
                pass

        self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
        # Make ``local_rank`` the current CUDA device while inside ``with``.
        self.device_ctx = (
            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        )

    def __enter__(self):
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Prepares model for DDP or DP. The model is moved to
        the device of the correct rank.

        Parameters
        ----------
        model : torch.nn.Module
            Model that is converted for DDP or DP.
        **kwargs
            Forwarded to ``DistributedDataParallel`` / ``DataParallel``.

        Returns
        -------
        torch.nn.Module
            Wrapped model, or original model if DDP and DP are turned off.
        """
        model = model.to(self.device)
        if self.use_ddp:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DistributedDataParallel(
                model, device_ids=[self.local_rank], **kwargs
            )
        elif self.use_dp:
            model = DataParallel(model, **kwargs)
        return model

    # Automatic mixed-precision utilities
    def autocast(self, *args, **kwargs):
        """Context manager for autocasting. Arguments
        go to ``torch.cuda.amp.autocast``; it is enabled iff ``amp`` is.
        """
        return torch.cuda.amp.autocast(self.amp, *args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backwards pass, after scaling the loss if ``amp`` is
        enabled.

        Parameters
        ----------
        loss : torch.Tensor
            Loss value.
        """
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Steps the optimizer, using a ``scaler`` if ``amp`` is
        enabled.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to step forward.
        """
        self.scaler.step(optimizer)

    def update(self):
        """Updates the scale factor (no-op when ``amp`` is disabled)."""
        self.scaler.update()

    def prepare_dataloader(
        self,
        dataset: typing.Iterable,
        start_idx: typing.Optional[int] = None,
        **kwargs,
    ):
        """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
        enabled.

        Under DDP, ``batch_size`` and ``num_workers`` (when given) are treated
        as global budgets and divided by the world size, with a floor of 1.
        ``num_workers=0`` and ``batch_size=None`` are passed through unchanged.

        Parameters
        ----------
        dataset : typing.Iterable
            Dataset to build Dataloader around.
        start_idx : int, optional
            Start index of sampler, useful if resuming from some epoch,
            by default None
        **kwargs
            Forwarded to ``torch.utils.data.DataLoader``. Must not contain
            ``sampler`` (one is supplied here).

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader iterating ``dataset`` with the selected sampler.
        """
        if self.use_ddp:
            sampler = DistributedSampler(
                dataset,
                start_idx,
                num_replicas=self.world_size,
                rank=self.local_rank,
            )
            num_workers = kwargs.get("num_workers")
            # 0 means "load in the main process"; don't silently turn it into 1.
            if num_workers:
                kwargs["num_workers"] = max(num_workers // self.world_size, 1)
            batch_size = kwargs.get("batch_size")
            # Absent -> DataLoader default; None -> automatic batching disabled.
            if batch_size is not None:
                kwargs["batch_size"] = max(batch_size // self.world_size, 1)
        else:
            sampler = SequentialSampler(dataset, start_idx)
        return torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)

    @staticmethod
    def unwrap(model):
        """Unwraps the model if it was wrapped in DDP or DP, otherwise
        just returns the model. Use this to unwrap the model returned by
        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
        """
        if hasattr(model, "module"):
            return model.module
        return model
|