# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from pathlib import Path
from datetime import datetime
import torch
import time
from collections import OrderedDict

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,  # general model non-sharded, non-flattened params
    LocalStateDictConfig,  # flattened params, usable only by FSDP
    # ShardedStateDictConfig,  # un-flattened params but sharded, usable by other parallel schemes.
)

from torch.distributed.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    save_state_dict,
    load_state_dict,
)
from torch.distributed.checkpoint.default_planner import (
    DefaultSavePlanner,
    DefaultLoadPlanner,
)

import torch.distributed.checkpoint as dist_cp
import torch.distributed as dist

import logging

logger = logging.getLogger(__name__)


def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    logger.info(f"--> current date and time of run = {date_of_run}")
    return date_of_run


# create a singleton saving policy to avoid re-creating it for every save
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)


def load_model_sharded(model, rank, cfg):
    # torch.manual_seed(103)
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    load_dir = Path.cwd() / folder_name

    if not load_dir.exists():
        if rank == 0:
            logger.info(f"No sharded_state_dict checkpoint directory found...skipping")
        return
    if rank == 0:
        logger.info(f"loading model from model path: {load_dir} ")
    reader = FileSystemReader(load_dir)

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        checkpoint = {"model": model.state_dict()}
        if rank == 0:
            ck = checkpoint.keys()
            logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}")

        dist_cp.load_state_dict(
            state_dict=checkpoint,
            storage_reader=reader,
        )
        if rank == 0:
            logger.info(f"checkpoint after load_state_dict()")
            ck = checkpoint.keys()
            logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
        model.load_state_dict(checkpoint["model"])
    if rank == 0:
        logger.info(f"Sharded state checkpoint loaded from {load_dir}")


def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
    """save model and optimizer via sharded_state_dict to save_dir"""

    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    save_dir = Path.cwd() / folder_name
    if rank == 0:
        logger.info(f"Saving model to {save_dir}")

    distributed_writer = dist_cp.FileSystemWriter(
        save_dir,
    )
    t0 = time.perf_counter()

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        if optim is not None:
            state_dict["optim"] = FSDP.optim_state_dict(model, optim)

        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=distributed_writer,
            planner=DefaultSavePlanner(),
        )
    dist.barrier()
    t1 = time.perf_counter()
    if rank == 0:
        logger.info(f"Sharded state checkpoint saved to {save_dir}")
        logger.info(f"Checkpoint Time = {t1-t0:.4f}\n")


def save_model_checkpoint(
    model,
    optimizer,
    rank,
    cfg,
    epoch=1,
):
    """saving model via rank0 cpu streaming and full_state_dict"""

    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
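        # Under FULL_STATE_DICT with fullstate_save_policy (offload_to_cpu=True,
        # rank0_only=True), state_dict() gathers the full, unsharded parameters:
        # only rank 0 receives the complete copy, already offloaded to CPU, while
        # the remaining ranks get an empty dict, keeping GPU memory flat.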
        cpu_state = model.state_dict()

        logger.info(f"saving process: rank {rank} done w model state_dict\n")

    if rank == 0:
        logger.info(f"--> saving model ...")
        # create save path
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)
        save_name = cfg.model_name + "-" + str(epoch) + ".pt"
        save_full_path = str(save_dir) + "/" + save_name

        # save model
        torch.save(cpu_state, save_full_path)

        logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")


def save_model_checkpoint_deepspeed(model, cfg, checkpoint_name="checkpoint"):
    logger.info(f"--> saving model ...")
    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)
    # save_full_path = os.path.join(save_dir, "model.pt")
    save_full_path = save_dir
    model.save_checkpoint(save_dir=save_full_path, exclude_frozen_parameters=True)
    logger.info(f"model checkpoint saved at {save_full_path}")


def save_model_checkpoint_peft(model, optimizer, rank, cfg, checkpoint_name="checkpoint", save_trainable_only=True):
    logger.info(f"--> saving model ...")
    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)
    save_full_path = os.path.join(save_dir, "model.pt")
    if cfg.enable_ddp:
        model = model.module
    cpu_state = model.state_dict()
    if save_trainable_only:
        state_dict = OrderedDict()
        for name, para in model.named_parameters():
            if para.requires_grad:
                state_dict[name] = cpu_state[name]
    else:
        state_dict = cpu_state
    torch.save(state_dict, save_full_path)
    logger.info(f"model checkpoint saved at {save_full_path}")


def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0):
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()
        logger.info(f"saving process: rank {rank} done w model state_dict\n")

    if rank == 0:
        logger.info(f"--> saving model ...")
        save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1))
        os.makedirs(save_dir, exist_ok=True)

        if not cfg.freeze_llm:
            llm_dict = {}
            for key in cpu_state.keys():
                if key.startswith("llm."):
                    llm_dict[key] = cpu_state[key]
            model.llm.save_pretrained(save_directory=save_dir, state_dict=llm_dict)
            logger.info(f"llm saved at {save_dir}")

        save_full_path = os.path.join(save_dir, "model.pt")
        encoder_dict = {}
        if not cfg.freeze_encoder:
            for key in cpu_state.keys():
                if key.startswith("encoder."):
                    encoder_dict[key] = cpu_state[key]
        for key in cpu_state.keys():
            if key.startswith("encoder_projector."):
                encoder_dict[key] = cpu_state[key]
        torch.save(encoder_dict, save_full_path)
        logger.info(f"encoder saved at {save_full_path}")

        logger.info(f"model checkpoint saved for epoch {epoch+1}\n")

    dist.barrier()
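
# --- Hypothetical usage sketch (not part of the original helpers) ---
# save_model_checkpoint_peft() above writes only the trainable parameters to
# <cfg.output_dir>/<checkpoint_name>/model.pt, so restoring that file onto a
# freshly constructed model needs strict=False: the frozen weights are expected
# to be absent from the checkpoint and keep their pretrained values.
# The function name, ckpt_path and device arguments are illustrative only.
def _example_load_peft_checkpoint(model, ckpt_path, device="cpu"):
    """Illustrative only: reload a trainable-params-only checkpoint."""
    trainable_state = torch.load(ckpt_path, map_location=device)
    result = model.load_state_dict(trainable_state, strict=False)
    logger.info(
        f"restored {len(trainable_state)} trainable tensors; "
        f"{len(result.missing_keys)} parameters were not in the checkpoint"
    )
    return model
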
def load_model_checkpoint(model, rank, cfg):
    """load local checkpoint to rank0 cpu
    must be called * before * passing to FSDP"""

    if rank != 0:
        return

    # where is the checkpoint at...
    full_state_dict_model_path = (
        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
    )
    # is it present...
    if not full_state_dict_model_path.is_file():
        logger.info(
            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
        )
        return

    model_checkpoint = torch.load(full_state_dict_model_path)
    # integrate into loaded model
    model.load_state_dict(model_checkpoint)

    logger.info(f"model checkpoint loaded to rank0 cpu")


def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
    """save optimizer state via full state dict"""

    logger.info(f"--> optim state call on rank {rank}\n")

    # pull all sharded optimizer states to rank0 cpu...
    optim_state = FSDP.full_optim_state_dict(model, optimizer)

    logger.info(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")

    if rank == 0:
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)

        opt_save_name = (
            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
        )
        opt_save_full_path = save_dir / opt_save_name

        logger.info(f"--> saving optimizer state...")
        torch.save(optim_state, opt_save_full_path)
        logger.info(f"--> saved {opt_save_full_path} to disk")


def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
    """load an fsdp optimizer full_state checkpoint using scatter method
    this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
    """

    if not optimizer_checkpoint_path.is_file():
        logger.info(
            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
        )
        return

    full_osd = None

    if rank == 0:
        full_osd = torch.load(optimizer_checkpoint_path)

    # called from all ranks, though only rank0 has a valid param for full_osd
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)

    logger.info(f"optimizer shard loaded on rank {rank}")

    # return the re-sharded state dict so the caller can apply it via optimizer.load_state_dict()
    return sharded_osd


def load_sharded_model_single_gpu(model, model_path):
    reader = FileSystemReader(model_path)

    state_dict = {
        "model": model.state_dict()
    }

    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=reader,
        no_dist=True,
    )
    model.load_state_dict(state_dict["model"])

    logger.info(f"Sharded state checkpoint loaded from {model_path}")
    return model
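
# --- Hypothetical end-to-end sketch (not part of the original helpers) ---
# Shows how the sharded save/load pair above is typically wired into a training
# loop. `cfg` is assumed to expose dist_checkpoint_root_folder,
# dist_checkpoint_folder and model_name, exactly as the functions above read it;
# the process group and FSDP wrapping are assumed to already exist.
def _example_sharded_checkpoint_roundtrip(model, optimizer, rank, cfg):
    """Illustrative only: write a sharded checkpoint, then restore the model."""
    # every rank participates; each writes its own shard under
    # <cwd>/<dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>
    save_model_and_optimizer_sharded(model, rank, cfg, optim=optimizer)
    # reload the shards back into the (still FSDP-wrapped) model
    load_model_sharded(model, rank, cfg)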