# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from pathlib import Path
from datetime import datetime
import torch
import time
from collections import OrderedDict

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,  # general model non-sharded, non-flattened params
    LocalStateDictConfig,  # flattened params, usable only by FSDP
    # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
)

from torch.distributed.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    save_state_dict,
    load_state_dict,
)
from torch.distributed.checkpoint.default_planner import (
    DefaultSavePlanner,
    DefaultLoadPlanner,
)
import torch.distributed.checkpoint as dist_cp
import torch.distributed as dist

import logging

logger = logging.getLogger(__name__)


def get_date_of_run():
    """Create a date-and-time string for checkpoint file-name uniqueness.

    Example: 2022-05-07-08:31:12_PM
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    logger.info(f"--> current date and time of run = {date_of_run}")
    return date_of_run


# create singleton saving policies to avoid making over and over
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)


def load_model_sharded(model, rank, cfg):
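    """Load a SHARDED_STATE_DICT checkpoint (as written by
    save_model_and_optimizer_sharded) back into an FSDP-wrapped model.
    """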
    # torch.manual_seed(103)
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    load_dir = Path.cwd() / folder_name

    if not load_dir.exists():
        if rank == 0:
            logger.info("No sharded_state_dict checkpoint directory found...skipping")
        return
    if rank == 0:
        logger.info(f"loading model from model path: {load_dir} ")
    reader = FileSystemReader(load_dir)

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        checkpoint = {"model": model.state_dict()}
        if rank == 0:
            ck = checkpoint.keys()
            logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}")

        dist_cp.load_state_dict(
            state_dict=checkpoint,
            storage_reader=reader,
        )
        if rank == 0:
            logger.info("checkpoint after load_state_dict()")
            ck = checkpoint.keys()
            logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
        model.load_state_dict(checkpoint["model"])
    if rank == 0:
        logger.info(f"Sharded state checkpoint loaded from {load_dir}")


def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
    """save model and optimizer via sharded_state_dict to save_dir"""
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    save_dir = Path.cwd() / folder_name
    if rank == 0:
        logger.info(f"Saving model to {save_dir}")

    distributed_writer = dist_cp.FileSystemWriter(
        save_dir,
    )
    t0 = time.perf_counter()

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        if optim is not None:
            state_dict["optim"] = FSDP.optim_state_dict(model, optim)

        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=distributed_writer,
            planner=DefaultSavePlanner(),
        )
    dist.barrier()
    t1 = time.perf_counter()
    if rank == 0:
        logger.info(f"Sharded state checkpoint saved to {save_dir}")
        logger.info(f"Checkpoint Time = {t1-t0:.4f}\n")
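
# Example usage (sketch): assumes torch.distributed is initialized, `model` is
# FSDP-wrapped, and `cfg` is any config object exposing dist_checkpoint_root_folder,
# dist_checkpoint_folder and model_name:
#
#   rank = dist.get_rank()
#   save_model_and_optimizer_sharded(model, rank, cfg, optim=optimizer)
#   # ...later, to resume from the same folder scheme:
#   load_model_sharded(model, rank, cfg)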


def save_model_checkpoint(
    model,
    optimizer,
    rank,
    cfg,
    epoch=1,
):
    """saving model via rank0 cpu streaming and full_state_dict"""

    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()

        logger.info(f"saving process: rank {rank} done w model state_dict\n")

    if rank == 0:
        logger.info("--> saving model ...")
        # create save path
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)
        save_name = cfg.model_name + "-" + str(epoch) + ".pt"
        save_full_path = str(save_dir) + "/" + save_name

        # save model
        torch.save(cpu_state, save_full_path)

        logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
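
# Example usage (sketch, same FSDP setup as above): because fullstate_save_policy sets
# rank0_only=True and offload_to_cpu=True, only rank 0 materializes the full state dict
# on CPU before writing it to disk:
#
#   save_model_checkpoint(model, optimizer, dist.get_rank(), cfg, epoch=epoch)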


def save_model_checkpoint_deepspeed(model, cfg, checkpoint_name="checkpoint"):
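    """Save a checkpoint through the DeepSpeed engine's save_checkpoint API,
    excluding frozen parameters to keep the checkpoint small.
    """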
    logger.info("--> saving model ...")
    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)
    # save_full_path = os.path.join(save_dir, "model.pt")
    save_full_path = save_dir
    model.save_checkpoint(save_dir=save_full_path, exclude_frozen_parameters=True)
    logger.info(f"model checkpoint saved at {save_full_path}")


def save_model_checkpoint_peft(model, optimizer, rank, cfg, checkpoint_name="checkpoint", save_trainable_only=True):
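    """Save a checkpoint to cfg.output_dir/checkpoint_name/model.pt.

    When save_trainable_only is True, only parameters with requires_grad=True
    (e.g. adapter and projector weights in a PEFT setup) are written.
    """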
    logger.info("--> saving model ...")
    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)
    save_full_path = os.path.join(save_dir, "model.pt")
    if cfg.enable_ddp:
        model = model.module
    cpu_state = model.state_dict()

    if save_trainable_only:
        state_dict = OrderedDict()
        for name, para in model.named_parameters():
            if para.requires_grad:
                state_dict[name] = cpu_state[name]
    else:
        state_dict = cpu_state

    torch.save(state_dict, save_full_path)
    logger.info(f"model checkpoint saved at {save_full_path}")


def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0):
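    """Gather a FULL_STATE_DICT from the FSDP-sharded model onto rank 0, then save
    the LLM weights via save_pretrained and the encoder / encoder_projector weights
    via torch.save, according to the freeze flags in cfg.
    """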
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()

        logger.info(f"saving process: rank {rank} done w model state_dict\n")

    if rank == 0:
        logger.info("--> saving model ...")
        save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch + 1))
        os.makedirs(save_dir, exist_ok=True)

        if not cfg.freeze_llm:
            llm_dict = {}
            for key in cpu_state.keys():
                if key.startswith("llm."):
                    llm_dict[key] = cpu_state[key]
            model.llm.save_pretrained(save_directory=save_dir, state_dict=llm_dict)
            logger.info(f"llm saved at {save_dir}")

        save_full_path = os.path.join(save_dir, "model.pt")
        encoder_dict = {}
        if not cfg.freeze_encoder:
            for key in cpu_state.keys():
                if key.startswith("encoder."):
                    encoder_dict[key] = cpu_state[key]
        for key in cpu_state.keys():
            if key.startswith("encoder_projector."):
                encoder_dict[key] = cpu_state[key]
        torch.save(encoder_dict, save_full_path)
        logger.info(f"encoder saved at {save_full_path}")

        logger.info(f"model checkpoint saved for epoch {epoch + 1}\n")

    dist.barrier()


def load_model_checkpoint(model, rank, cfg):
    """load local checkpoint to rank0 cpu
    must be called * before * passing to FSDP"""

    if rank != 0:
        return

    # where is the checkpoint at...
    full_state_dict_model_path = (
        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
    )
    # is it present...
    if not full_state_dict_model_path.is_file():
        logger.info(
            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
        )
        return

    model_checkpoint = torch.load(full_state_dict_model_path)
    # integrate into loaded model
    model.load_state_dict(model_checkpoint)

    logger.info("model checkpoint loaded to rank0 cpu")


def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
    """save optimizer state via full state dict"""

    logger.info(f"--> optim state call on rank {rank}\n")

    # pull all sharded optimizer states to rank0 cpu...
    optim_state = FSDP.full_optim_state_dict(model, optimizer)

    logger.info(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")

    if rank == 0:
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)

        opt_save_name = "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
        opt_save_full_path = save_dir / opt_save_name

        logger.info("--> saving optimizer state...")
        torch.save(optim_state, opt_save_full_path)
        logger.info(f"--> saved {opt_save_full_path} to disk")


def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
    """load an fsdp optimizer full_state checkpoint using scatter method
    this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
    """

    if not optimizer_checkpoint_path.is_file():
        logger.info(
            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
        )
        return

    full_osd = None

    if rank == 0:
        full_osd = torch.load(optimizer_checkpoint_path)

    # called from all ranks, though only rank0 has a valid param for full_osd
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)

    logger.info(f"optimizer shard loaded on rank {rank}")
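
# Example usage (sketch; the path must follow the naming scheme used by
# save_optimizer_checkpoint above, i.e. "optimizer-<model_name>-<epoch>.pt" inside
# "<dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>"):
#
#   opt_path = Path.cwd() / folder_name / f"optimizer-{cfg.model_name}-{epoch}.pt"
#   load_optimizer_checkpoint(model, opt_path, dist.get_rank())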


def load_sharded_model_single_gpu(model, model_path):
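    """Load a sharded (distributed) checkpoint into a regular, non-FSDP model in a
    single process, using no_dist=True so no process group is required.
    """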
    reader = FileSystemReader(model_path)

    state_dict = {
        "model": model.state_dict()
    }

    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=reader,
        no_dist=True,
    )

    model.load_state_dict(state_dict["model"])

    logger.info(f"Sharded state checkpoint loaded from {model_path}")
    return model
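
# Example usage (sketch, assuming a sharded checkpoint directory and a Hugging
# Face-style model class; the names below are illustrative, not part of this module):
#
#   model = AutoModelForCausalLM.from_pretrained(base_model_name)
#   model = load_sharded_model_single_gpu(model, "path/to/sharded_checkpoint_dir")
#   model.save_pretrained("path/to/consolidated_hf_model")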