import functools
import json
import operator
import os
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from videosys.core.comm import model_sharding


def load_json(file_path: str):
    with open(file_path, "r") as f:
        return json.load(f)


def save_json(data, file_path: str):
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)


def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
    # Keep only the first prod(original_shape) elements, dropping any trailing
    # padding that was appended to make the flattened parameter shardable.
    return tensor[: functools.reduce(operator.mul, original_shape)]


def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
    # All-gather every sharded parameter and, on rank 0, strip the padding and
    # restore the original (unsharded) shape recorded in model_shape_dict.
    global_rank = dist.get_rank()
    global_size = dist.get_world_size()
    for name, param in model.named_parameters():
        all_params = [torch.empty_like(param.data) for _ in range(global_size)]
        dist.all_gather(all_params, param.data, group=dist.group.WORLD)
        if global_rank == 0:
            all_params = torch.cat(all_params)
            param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
    dist.barrier()


def record_model_param_shape(model: torch.nn.Module) -> dict:
    # Map parameter names to their original shapes so they can be restored
    # after gathering.
    param_shape = {}
    for name, param in model.named_parameters():
        param_shape[name] = param.shape
    return param_shape


def save(
    booster: Booster,
    model: nn.Module,
    ema: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    global_step: int,
    batch_size: int,
    coordinator: DistCoordinator,
    save_dir: str,
    shape_dict: dict,
    shard_ema: bool = False,
):
    torch.cuda.empty_cache()
    global_rank = dist.get_rank()
    save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, "model"), shard=True)

    # Gather the sharded ema model before saving
    if shard_ema:
        model_gathering(ema, shape_dict)
    # ema is not boosted, so we don't need to use booster.save_model
    if global_rank == 0:
        torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
    # Shard ema model again when using the zero2 plugin
    if shard_ema:
        model_sharding(ema)

    if optimizer is not None:
        booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
    if lr_scheduler is not None:
        booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))

    running_states = {
        "epoch": epoch,
        "step": step,
        "global_step": global_step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))
    dist.barrier()


def load(
    booster: Booster,
    model: nn.Module,
    ema: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    load_dir: str,
) -> Tuple[int, int, int]:
    booster.load_model(model, os.path.join(load_dir, "model"))
    # ema is not boosted, so we don't use booster.load_model
    ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
    if optimizer is not None:
        booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
    if lr_scheduler is not None:
        booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
    running_states = load_json(os.path.join(load_dir, "running_states.json"))
    dist.barrier()
    torch.cuda.empty_cache()
    return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
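

# --- Hedged usage sketch (not part of the original module) ---
# A minimal single-process example showing how `record_model_param_shape` and
# `remove_padding` round-trip a flattened, padded parameter. The pad-to-a-multiple-
# of-world-size step below only mimics what ZeRO-style sharding (see
# `model_sharding`) is assumed to do before `model_gathering` reverses it;
# `world_size = 4` is a hypothetical stand-in, since no process group is
# initialized here. Run with plain `python` to check the round trip.
if __name__ == "__main__":
    layer = nn.Linear(3, 5)
    shape_dict = record_model_param_shape(layer)

    world_size = 4  # hypothetical; in training this would be dist.get_world_size()
    for name, param in layer.named_parameters():
        flat = param.data.flatten()
        # Pad the flattened parameter so it splits evenly across `world_size` ranks.
        pad = (world_size - flat.numel() % world_size) % world_size
        padded = torch.cat([flat, torch.zeros(pad, dtype=flat.dtype)])
        # Strip the padding and restore the recorded shape, as `model_gathering` does.
        restored = remove_padding(padded, shape_dict[name]).view(shape_dict[name])
        assert torch.equal(restored, param.data)
    print("remove_padding round-trip OK")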