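# Repository-viewer header captured with this file (kept as comments so the
# module stays importable):
#   BioM3 / Stage3_source / DSEma.py
#   Niksa Praljak
#   Add scripts for ProteoScribe Sampling
#   c865888
#   raw | history | blame
#   1.99 kB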
from torch import nn
import torch
from deepspeed.runtime.zero import GatheredParameters
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
def _z3_params_to_fetch(param_list):
    """Return the entries of ``param_list`` that must be gathered before use.

    An entry is selected when it carries a ``ds_id`` attribute and its
    ``ds_status`` equals ``ZeroParamStatus.NOT_AVAILABLE``. Entries without
    ``ds_id`` are skipped without their ``ds_status`` being read.

    Args:
        param_list: Iterable of parameter-like objects.

    Returns:
        A new list holding the selected entries, in their original order.
    """
    pending = []
    for candidate in param_list:
        # Objects lacking ``ds_id`` are never inspected further.
        if not hasattr(candidate, 'ds_id'):
            continue
        if candidate.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            pending.append(candidate)
    return pending
def _ema_blend_(param, param_ema, beta, device):
    """Overwrite ``param_ema`` with ``beta * param_ema + (1 - beta) * param``."""
    data = param.data
    if device is not None:
        data = data.to(device)
    # torch.lerp(start, end, w) == start + w * (end - start); with start=data
    # and end=param_ema this keeps a ``beta`` share of the existing EMA value.
    param_ema.data.copy_(torch.lerp(data, param_ema.data, beta))


def moving_average(model, model_ema, beta=0.9999, device=None, zero_stage=3):
    """Update ``model_ema`` in place as an exponential moving average of ``model``.

    Parameters are paired positionally with ``zip`` over ``model.parameters()``
    and ``model_ema.parameters()``; if the two yield different counts, the
    surplus parameters are silently ignored. For every pair the EMA parameter
    becomes ``beta * param_ema + (1 - beta) * param``. The whole update runs
    under ``torch.no_grad()``.

    Args:
        model: Module providing the current parameter values (read only).
        model_ema: Module whose parameters are overwritten with the new average.
        beta: Weight kept on the existing EMA value.
        device: When not ``None``, the source data is moved to this device
            before blending.
        zero_stage: ZeRO stage in use. Only the value ``3`` enables gathering
            of parameters reported by ``_z3_params_to_fetch``.

    Returns:
        None.
    """
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for param, param_ema in zip(model.parameters(),
                                    model_ema.parameters()):
            # TODO: use prefiltering for efficiency
            src_fetch = _z3_params_to_fetch([param]) if zero_stage_3 else []
            ema_fetch = _z3_params_to_fetch([param_ema]) if zero_stage_3 else []

            if not src_fetch and not ema_fetch:
                # Nothing needs gathering: update directly.
                _ema_blend_(param, param_ema, beta, device)
                continue

            # The source is only read, so it is gathered without a modifier.
            with deepspeed.zero.GatheredParameters(
                    src_fetch, enabled=bool(src_fetch)):
                # The EMA parameter is written inside the context. DeepSpeed
                # documents ``modifier_rank`` as required when gathered
                # parameters are modified, so that the update is kept when the
                # context exits. Every rank performs the same update here, so
                # rank 0 is used as the modifier.
                with deepspeed.zero.GatheredParameters(
                        ema_fetch, modifier_rank=0, enabled=bool(ema_fetch)):
                    _ema_blend_(param, param_ema, beta, device)
def clone_zero_model(src_model, dst_model, zero_stage=0):
    """Copy every parameter of ``src_model`` into ``dst_model`` in place.

    Parameters are paired positionally with ``zip`` over
    ``src_model.parameters()`` and ``dst_model.parameters()``; if the two
    yield different counts, the surplus parameters are silently ignored.
    The copy runs under ``torch.no_grad()``.

    Args:
        src_model: Module providing the values to copy (read only).
        dst_model: Module whose parameters are overwritten.
        zero_stage: ZeRO stage in use. Only the value ``3`` enables gathering
            of parameters reported by ``_z3_params_to_fetch``.

    Returns:
        None.
    """
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for src_param, dst_param in zip(src_model.parameters(),
                                        dst_model.parameters()):
            # TODO: use prefiltering for efficiency
            src_fetch = _z3_params_to_fetch([src_param]) if zero_stage_3 else []
            dst_fetch = _z3_params_to_fetch([dst_param]) if zero_stage_3 else []

            if not src_fetch and not dst_fetch:
                # Nothing needs gathering: copy directly.
                dst_param.data.copy_(src_param.data)
                continue

            # The source is only read, so it is gathered without a modifier.
            with deepspeed.zero.GatheredParameters(
                    src_fetch, enabled=bool(src_fetch)):
                # The destination is written inside the context. DeepSpeed
                # documents ``modifier_rank`` as required when gathered
                # parameters are modified, so that the copy is kept when the
                # context exits. Every rank performs the same copy here, so
                # rank 0 is used as the modifier.
                with deepspeed.zero.GatheredParameters(
                        dst_fetch, modifier_rank=0, enabled=bool(dst_fetch)):
                    dst_param.data.copy_(src_param.data)