import itertools
from typing import Union, Iterable, Dict, Iterator

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, group

__all__ = [
    "average_parameters",
    "get_params_to_average",
    "average_parameters_or_parameter_groups",
]


def average_parameters(
    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
):
    """
    Averages all the given parameters.

    For allreduce efficiency, all the parameters are flattened into a contiguous buffer.
    Thus, it requires extra memory of the same size as the given parameters.
    """
    group_to_use = process_group if process_group is not None else group.WORLD
    # Do not update any parameter if the current rank is not in the process group.
    if dist._rank_not_in_group(group_to_use):
        return

    # The input iterator is consumed twice: once to pack the parameters into a
    # flat buffer, and once to unpack the averaged result back into them.
    params_it1, params_it2 = itertools.tee(params)
    # If the input parameters have different data types, packing them into one
    # buffer triggers an implicit type promotion; the original data types are
    # restored by ``type_as`` during the unpacking below.
    flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
    # Pre-divide by the world size so the sum-allreduce below yields the average.
    flat_params /= dist.get_world_size(group_to_use)
    # Ensure all pending CUDA work finishes before the collective starts.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dist.all_reduce(flat_params, group=group_to_use)

    # Unpack: copy each averaged slice back into the corresponding parameter.
    offset = 0
    for p in params_it2:
        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
        offset += p.numel()


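# A minimal usage sketch (illustrative only, not part of this module's API: it
# assumes the default process group has already been initialized on every rank,
# e.g. via ``dist.init_process_group``, and that each rank builds the same
# model; the model below is a placeholder):
#
#     >>> model = torch.nn.Linear(2, 2)
#     >>> # ... run a few local training steps on each rank ...
#     >>> average_parameters(iter(model.parameters()), group.WORLD)
#
# Note that ``params`` is consumed as an iterator, so a fresh iterator must be
# passed on every call.

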
def get_params_to_average(
    params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]
):
    """
    Returns a list of parameters that need to be averaged, filtering out the
    parameters whose gradients are not available.

    Args:
        params: The parameters of a model or the parameter groups of an optimizer.
    """
    filtered_params = []
    for param in params:
        if isinstance(param, torch.nn.Parameter):
            # A plain model parameter.
            param_data = param
            if param_data.grad is not None:
                filtered_params.append(param_data)
        elif isinstance(param, dict):
            # A parameter group of an optimizer, e.g. an entry of
            # ``optimizer.param_groups``.
            for param_data in param["params"]:
                if param_data.grad is not None:
                    filtered_params.append(param_data)
        else:
            raise NotImplementedError(
                f"Parameter input of type {type(param)} is not supported"
            )
    return filtered_params


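# A sketch of the two accepted input shapes (illustrative only; the model and
# optimizer names are placeholders):
#
#     >>> model = torch.nn.Linear(2, 2)
#     >>> get_params_to_average(model.parameters())   # model parameters
#     >>> opt = torch.optim.SGD(model.parameters(), lr=0.1)
#     >>> get_params_to_average(opt.param_groups)     # optimizer parameter groups
#
# Parameters whose ``.grad`` is still ``None`` (e.g. before the first backward
# pass) are skipped, so both calls return ``[]`` until gradients exist.

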
def average_parameters_or_parameter_groups(
    params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]],
    process_group: ProcessGroup,
):
    """
    Averages parameters of a model or parameter groups of an optimizer.
    """
    average_parameters(iter(get_params_to_average(params)), process_group)


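# Example sketch (same assumptions as above: an initialized default process
# group and identical models across ranks; names are placeholders):
#
#     >>> model = torch.nn.Linear(2, 2)
#     >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     >>> # ... after local forward/backward/step on each rank ...
#     >>> average_parameters_or_parameter_groups(optimizer.param_groups, group.WORLD)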
|