# mini-omni-s2s/slam_llm/utils/checkpoint_handler.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
from pathlib import Path
from datetime import datetime
import torch
import time
from collections import OrderedDict
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,  # general model non-sharded, non-flattened params
    LocalStateDictConfig,  # flattened params, usable only by FSDP
    # ShardedStateDictConfig,  # un-flattened params but sharded, usable by other parallel schemes.
)
from torch.distributed.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    save_state_dict,
    load_state_dict,
)
from torch.distributed.checkpoint.default_planner import (
    DefaultSavePlanner,
    DefaultLoadPlanner,
)
import torch.distributed.checkpoint as dist_cp
import torch.distributed as dist
import logging

logger = logging.getLogger(__name__)


def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    logger.info(f"--> current date and time of run = {date_of_run}")
    return date_of_run


# create a singleton saving policy up front to avoid re-creating it for every checkpoint
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
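# Note: FULL_STATE_DICT with offload_to_cpu=True and rank0_only=True gathers the full,
# unflattened parameters onto CPU on rank 0 only, which keeps per-GPU memory flat while saving.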


def load_model_sharded(model, rank, cfg):
    # torch.manual_seed(103)
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    load_dir = Path.cwd() / folder_name

    if not load_dir.exists():
        if rank == 0:
            logger.info("No sharded_state_dict checkpoint directory found...skipping")
        return
    if rank == 0:
        logger.info(f"loading model from model path: {load_dir} ")
    reader = FileSystemReader(load_dir)

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        checkpoint = {"model": model.state_dict()}
        if rank == 0:
            ck = checkpoint.keys()
            logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}")

        dist_cp.load_state_dict(
            state_dict=checkpoint,
            storage_reader=reader,
        )
        if rank == 0:
            logger.info("checkpoint after load_state_dict()")
            ck = checkpoint.keys()
            logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
        model.load_state_dict(checkpoint["model"])
    if rank == 0:
        logger.info(f"Sharded state checkpoint loaded from {load_dir}")


def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
    """save model and optimizer via sharded_state_dict to save_dir"""
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    save_dir = Path.cwd() / folder_name
    if rank == 0:
        logger.info(f"Saving model to {save_dir}")

    distributed_writer = dist_cp.FileSystemWriter(
        save_dir,
    )
    t0 = time.perf_counter()

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        if optim is not None:
            state_dict["optim"] = FSDP.optim_state_dict(model, optim)

        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=distributed_writer,
            planner=DefaultSavePlanner(),
        )
    dist.barrier()
    t1 = time.perf_counter()
    if rank == 0:
        logger.info(f"Sharded state checkpoint saved to {save_dir}")
        logger.info(f"Checkpoint Time = {t1 - t0:.4f}\n")
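
# Illustrative usage (a sketch with placeholder names; every rank must call this, since writing
# a sharded checkpoint is a collective operation):
#   save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)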


def save_model_checkpoint(
    model,
    optimizer,
    rank,
    cfg,
    epoch=1,
):
    """saving model via rank0 cpu streaming and full_state_dict"""

    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()

        logger.info(f"saving process: rank {rank} done w model state_dict\n")

    if rank == 0:
        logger.info("--> saving model ...")
        # create save path
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)
        save_name = cfg.model_name + "-" + str(epoch) + ".pt"
        save_full_path = str(save_dir) + "/" + save_name

        # save model
        torch.save(cpu_state, save_full_path)

        logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")


def save_model_checkpoint_deepspeed(model, cfg, checkpoint_name="checkpoint"):
    logger.info("--> saving model ...")
    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)
    # save_full_path = os.path.join(save_dir, "model.pt")
    save_full_path = save_dir
    model.save_checkpoint(save_dir=save_full_path, exclude_frozen_parameters=True)
    logger.info(f"deepspeed checkpoint saved at {save_full_path}")


def save_model_checkpoint_peft(model, optimizer, rank, cfg, checkpoint_name="checkpoint", save_trainable_only=True):
    logger.info("--> saving model ...")
    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)
    save_full_path = os.path.join(save_dir, "model.pt")
    if cfg.enable_ddp:
        model = model.module
    cpu_state = model.state_dict()
    if save_trainable_only:
        state_dict = OrderedDict()
        for name, para in model.named_parameters():
            if para.requires_grad:
                state_dict[name] = cpu_state[name]
    else:
        state_dict = cpu_state
    torch.save(state_dict, save_full_path)
    logger.info(f"model checkpoint saved at {save_full_path}")


def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0):
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()
        logger.info(f"saving process: rank {rank} done w model state_dict\n")

    if rank == 0:
        logger.info("--> saving model ...")
        save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch + 1))
        os.makedirs(save_dir, exist_ok=True)

        if not cfg.freeze_llm:
            llm_dict = {}
            for key in cpu_state.keys():
                if key.startswith("llm."):
                    llm_dict[key] = cpu_state[key]
            model.llm.save_pretrained(save_directory=save_dir, state_dict=llm_dict)
            logger.info(f"llm saved at {save_dir}")

        save_full_path = os.path.join(save_dir, "model.pt")
        encoder_dict = {}
        if not cfg.freeze_encoder:
            for key in cpu_state.keys():
                if key.startswith("encoder."):
                    encoder_dict[key] = cpu_state[key]
        for key in cpu_state.keys():
            if key.startswith("encoder_projector."):
                encoder_dict[key] = cpu_state[key]
        torch.save(encoder_dict, save_full_path)
        logger.info(f"encoder saved at {save_full_path}")

        logger.info(f"model checkpoint saved for epoch {epoch + 1}\n")
    dist.barrier()
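
# Illustrative usage (a sketch with placeholder names; all ranks must call it, because the
# FULL_STATE_DICT gather and the final dist.barrier() are collective operations):
#   save_model_checkpoint_peft_full_shard(model, optimizer, rank, train_config, epoch=epoch)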


def load_model_checkpoint(model, rank, cfg):
    """load local checkpoint to rank0 cpu
    must be called *before* passing to FSDP"""

    if rank != 0:
        return

    # where is the checkpoint at...
    full_state_dict_model_path = (
        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
    )
    # is it present...
    if not full_state_dict_model_path.is_file():
        logger.info(
            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
        )
        return

    model_checkpoint = torch.load(full_state_dict_model_path)
    # integrate into loaded model
    model.load_state_dict(model_checkpoint)

    logger.info("model checkpoint loaded to rank0 cpu")


def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
    """save optimizer state via full state dict"""

    logger.info(f"--> optim state call on rank {rank}\n")

    # pull all sharded optimizer states to rank0 cpu...
    optim_state = FSDP.full_optim_state_dict(model, optimizer)

    logger.info(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")

    if rank == 0:
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)

        opt_save_name = "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
        opt_save_full_path = save_dir / opt_save_name

        logger.info("--> saving optimizer state...")
        torch.save(optim_state, opt_save_full_path)
        logger.info(f"--> saved {opt_save_full_path} to disk")


def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
    """load an fsdp optimizer full_state checkpoint using the scatter method
    this ensures only rank 0 loads the optimizer state dict and scatters it to the other ranks
    """

    if not optimizer_checkpoint_path.is_file():
        logger.info(
            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
        )
        return

    full_osd = None
    if rank == 0:
        full_osd = torch.load(optimizer_checkpoint_path)

    # called from all ranks, though only rank0 has a valid param for full_osd
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)

    logger.info(f"optimizer shard loaded on rank {rank}")
    # return the scattered shard so the caller can apply it via optimizer.load_state_dict()
    return sharded_osd


def load_sharded_model_single_gpu(model, model_path):
    reader = FileSystemReader(model_path)

    state_dict = {
        "model": model.state_dict()
    }

    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=reader,
        no_dist=True,
    )

    model.load_state_dict(state_dict["model"])

    logger.info(f"Sharded state checkpoint loaded from {model_path}")
    return model
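
# Illustrative usage (a sketch; handy for converting a sharded FSDP checkpoint into a single
# consolidated file in one process -- no process group is needed because of no_dist=True):
#   model = load_sharded_model_single_gpu(model, "path/to/sharded_checkpoint_dir")
#   torch.save(model.state_dict(), "consolidated.pt")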