import torch

import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus


def _z3_params_to_fetch(param_list):
    """Return the parameters in param_list that ZeRO stage 3 has partitioned
    away and that must therefore be gathered before use."""
    return [
        p for p in param_list
        if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
    ]
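

# Hypothetical usage sketch (not part of the original module): pair
# _z3_params_to_fetch with deepspeed.zero.GatheredParameters so partitioned
# parameters are materialized only for the duration of the block.
def _gather_example(params):
    # `params` is any iterable of torch.nn.Parameter; with a ZeRO-3 engine
    # the partitioned ones are gathered for the duration of the context.
    params = list(params)
    to_fetch = _z3_params_to_fetch(params)
    with deepspeed.zero.GatheredParameters(to_fetch,
                                           enabled=len(to_fetch) > 0):
        # Inside the context every parameter's full .data is available.
        return [p.data.norm().item() for p in params]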


def moving_average(model, model_ema, beta=0.9999, device=None, zero_stage=3):
    """Update model_ema in place as an exponential moving average of model:
    ema = beta * ema + (1 - beta) * param."""
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for param, param_ema in zip(model.parameters(),
                                    model_ema.parameters()):
            # Under ZeRO stage 3 the full parameters must be gathered from
            # their partitions before they can be read or written.
            params_to_fetch = _z3_params_to_fetch(
                [param, param_ema]) if zero_stage_3 else []
            should_gather_param = len(params_to_fetch) > 0
            with deepspeed.zero.GatheredParameters(params_to_fetch,
                                                   enabled=should_gather_param):
                data = param.data
                if device is not None:
                    data = data.to(device)
                # torch.lerp(data, ema, beta) = (1 - beta) * data + beta * ema
                param_ema.data.copy_(torch.lerp(data, param_ema.data, beta))
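

# Hypothetical usage sketch (not part of the original module): maintain an
# EMA copy of a model during plain single-process training. `model`,
# `optimizer`, and the loop below are stand-ins for a real training setup.
def _ema_training_example():
    import copy
    model = torch.nn.Linear(8, 8)
    model_ema = copy.deepcopy(model)
    for p in model_ema.parameters():
        p.requires_grad_(False)  # the EMA copy is never trained directly
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Refresh the EMA weights after every optimizer step.
        moving_average(model, model_ema, beta=0.9999, zero_stage=0)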


def clone_zero_model(src_model, dst_model, zero_stage=0):
    """Copy src_model's parameters into dst_model in place, gathering any
    ZeRO stage 3 partitioned parameters before the copy."""
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for src_param, dst_param in zip(src_model.parameters(),
                                        dst_model.parameters()):
            # Gather both parameters if ZeRO stage 3 has partitioned them.
            params_to_fetch = _z3_params_to_fetch(
                [src_param, dst_param]) if zero_stage_3 else []
            should_gather_param = len(params_to_fetch) > 0
            with deepspeed.zero.GatheredParameters(params_to_fetch,
                                                   enabled=should_gather_param):
                dst_param.data.copy_(src_param.data)
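

# Hypothetical usage sketch (not part of the original module): clone weights
# from a trained model into an architecturally identical one, e.g. to refresh
# a frozen reference model. With zero_stage=3 the same call would gather each
# partitioned parameter pair before copying.
def _clone_example():
    src = torch.nn.Linear(8, 8)
    dst = torch.nn.Linear(8, 8)
    clone_zero_model(src, dst, zero_stage=0)
    assert torch.equal(src.weight, dst.weight)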