Open-Sora / opensora /utils /train_utils.py
kadirnar's picture
Upload 98 files
e7d5680 verified
raw
history blame
1.06 kB
from collections import OrderedDict
import torch
@torch.no_grad()
def update_ema(
ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
) -> None:
"""
Step the EMA model towards the current model.
"""
ema_params = OrderedDict(ema_model.named_parameters())
model_params = OrderedDict(model.named_parameters())
for name, param in model_params.items():
if name == "pos_embed":
continue
if param.requires_grad == False:
continue
if not sharded:
param_data = param.data
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
else:
if param.data.dtype != torch.float32:
param_id = id(param)
master_param = optimizer._param_store.working_to_master_param[param_id]
param_data = master_param.data
else:
param_data = param.data
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)