# Source: mini-omni-s2s / slam_llm/utils/train_utils.py
# Uploaded by xcczach ("Upload 73 files", commit 35c1cfd).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import time
import yaml
from contextlib import nullcontext
from pathlib import Path
from pkg_resources import packaging
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from tqdm import tqdm
from transformers import LlamaTokenizer
from slam_llm.utils.checkpoint_handler import (
save_model_checkpoint,
save_model_and_optimizer_sharded,
save_optimizer_checkpoint,
save_model_checkpoint_peft,
save_model_checkpoint_peft_full_shard
)
from slam_llm.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
from slam_llm.utils.memory_utils import MemoryTrace
from slam_llm.utils.metric import compute_accuracy
import wandb
import logging
# Module-level logger used by every helper in this file.
logger = logging.getLogger(__name__)
def set_tokenizer_params(tokenizer: LlamaTokenizer):
    """Configure ``tokenizer`` for batching: pad token id 0, padding on the left."""
    settings = {"pad_token_id": 0, "padding_side": "left"}
    for attr_name, attr_value in settings.items():
        setattr(tokenizer, attr_name, attr_value)
# Byte -> megabyte (2**20 bytes) conversion helper
def byte2mb(x):
    """Convert a byte count ``x`` to megabytes (2**20 bytes), truncating toward zero."""
    bytes_per_mb = 1 << 20
    return int(x / bytes_per_mb)
def _move_batch_to_device(batch, device):
    """
    Move every tensor in ``batch`` to ``device`` in place.

    Top-level tensor values are moved, as are tensor values one level deep inside
    nested dicts; all other values are left untouched.
    """
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
        elif isinstance(value, dict):
            for sub_key, sub_value in value.items():
                if isinstance(sub_value, torch.Tensor):
                    value[sub_key] = sub_value.to(device)


def _log_train_inner_metrics(loss, acc, audio_acc, layer_loss, log_step):
    """
    Log one training step's metrics to wandb at ``log_step``.

    ``audio_acc`` holds per-layer audio accuracies; ``layer_loss`` holds per-layer
    losses with the audio layers first and the text loss last. Either may be empty
    when the model does not report them. Rank/interval gating is the caller's job.
    """
    metrics = {
        "train_inner/train_inner_loss": loss,
        "train_inner/train_inner_text_accuracy": acc,
    }
    for layer, layer_acc in enumerate(audio_acc):
        metrics[f"train_inner/train_inner_audio_accuracy_layer{layer}"] = layer_acc
    if len(layer_loss) > 0:
        for layer, audio_loss in enumerate(layer_loss[:-1]):
            metrics[f"train_inner/train_inner_audio_loss_layer{layer}"] = audio_loss
        metrics["train_inner/train_inner_text_loss"] = layer_loss[-1]
    wandb.log(metrics, step=log_step)


def _save_trainable_checkpoint(model, optimizer, rank, train_config, fsdp_config, epoch, checkpoint_name):
    """
    Save a checkpoint through the PEFT-style checkpoint helpers.

    FSDP: FULL_SHARD uses ``save_model_checkpoint_peft_full_shard`` on all ranks;
    NO_SHARD saves on rank 0 and then barriers (other strategies save nothing).
    DDP: rank 0 saves, then barrier. Single process: save directly.
    """
    if train_config.enable_fsdp:
        strategy = fsdp_config.sharding_strategy
        if strategy == ShardingStrategy.FULL_SHARD:
            save_model_checkpoint_peft_full_shard(
                model, optimizer, rank, train_config, epoch=epoch
            )
        elif strategy == ShardingStrategy.NO_SHARD:
            if rank == 0:
                save_model_checkpoint_peft(
                    model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
                )
            dist.barrier()
    elif train_config.enable_ddp:
        if rank == 0:
            save_model_checkpoint_peft(
                model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
            )
        dist.barrier()
    else:
        save_model_checkpoint_peft(
            model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
        )


def _save_fsdp_full_checkpoint(model, optimizer, rank, train_config, fsdp_config, epoch):
    """Save a full (non-PEFT) FSDP checkpoint according to ``fsdp_config.checkpoint_type``."""
    state_dict_type = getattr(StateDictType, fsdp_config.checkpoint_type)
    if state_dict_type == StateDictType.FULL_STATE_DICT:
        save_model_checkpoint(
            model, optimizer, rank, train_config, epoch=epoch
        )
    elif state_dict_type == StateDictType.SHARDED_STATE_DICT:
        logger.info(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
        logger.info("=====================================================")
        save_model_and_optimizer_sharded(model, rank, train_config)
        if train_config.save_optimizer:
            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
            logger.info(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
            logger.info("=====================================================")
    if train_config.save_optimizer:
        save_optimizer_checkpoint(
            model, optimizer, rank, train_config, epoch=epoch
        )
        logger.info(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
        logger.info("=====================================================")


def _save_best_checkpoint(model, optimizer, rank, train_config, fsdp_config, epoch, checkpoint_name, is_main):
    """
    Save a checkpoint after validation loss improved.

    Dispatches on PEFT / frozen-LLM / full-model training and on FSDP / DDP /
    single-process execution; distributed runs are bracketed by barriers.
    """
    distributed = train_config.enable_fsdp or train_config.enable_ddp
    if distributed:
        dist.barrier()
    if train_config.use_peft:
        if is_main:
            logger.info("we are about to save the PEFT modules")
        _save_trainable_checkpoint(model, optimizer, rank, train_config, fsdp_config, epoch, checkpoint_name)
        if is_main:
            logger.info(f"PEFT modules are saved in {train_config.output_dir} directory")
    elif train_config.freeze_llm:
        logger.info("llm is frozen, we are about to save other parts.")
        _save_trainable_checkpoint(model, optimizer, rank, train_config, fsdp_config, epoch, checkpoint_name)
    elif train_config.enable_fsdp:
        _save_fsdp_full_checkpoint(model, optimizer, rank, train_config, fsdp_config, epoch)
    else:
        # Without FSDP, full-model runs use the same DDP / single-process save path.
        _save_trainable_checkpoint(model, optimizer, rank, train_config, fsdp_config, epoch, checkpoint_name)
    if distributed:
        dist.barrier()


def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, log_config, fsdp_config=None, local_rank=None, rank=None):
    """
    Train ``model`` for ``train_config.num_epochs`` epochs with periodic validation and checkpointing.

    Args:
        model: model whose forward returns ``(outputs, *extras)``. ``outputs.loss`` is the
            training loss; the optional extras are, in order, the text accuracy, a sequence
            of per-layer audio accuracies, and a sequence of per-layer losses (audio layers
            first, text loss last).
        train_dataloader: yields dict batches; tensors (also inside nested dicts) are moved
            to ``local_rank`` in distributed runs, otherwise to ``cuda:0``.
        eval_dataloader: dataloader passed to ``evaluation`` during validation.
        tokenizer: forwarded to ``evaluation``.
        optimizer: optimizer stepped once every ``gradient_accumulation_steps`` batches
            (and on the last batch of an epoch).
        lr_scheduler: optional scheduler stepped together with the optimizer. When the
            learning rate reaches 0 the current epoch ends early.
        gradient_accumulation_steps: number of batches whose gradients are accumulated
            per optimizer step.
        train_config: training configuration (epochs, fp16 / FSDP / DDP flags, PEFT and
            freezing flags, validation and checkpoint settings).
        log_config: logging configuration (``use_wandb``, ``log_interval``).
        fsdp_config: FSDP settings (sharding strategy, checkpoint type); only read when
            FSDP is enabled.
        local_rank: device index for batches in distributed runs.
        rank: global rank; in distributed runs only rank 0 logs and performs the
            rank-0-only checkpoint writes.

    Returns:
        dict with average training perplexity/loss/accuracy, average evaluation
        perplexity/loss/accuracy (only if validation actually ran), and average epoch
        and checkpoint times.
    """
    distributed = train_config.enable_fsdp or train_config.enable_ddp
    # In distributed runs only rank 0 logs; a single process always logs.
    is_main = not distributed or rank == 0
    device = local_rank if distributed else 'cuda:0'

    # Create a gradient scaler for fp16; FSDP needs the sharded variant.
    if train_config.use_fp16:
        scaler = ShardedGradScaler() if train_config.enable_fsdp else torch.cuda.amp.GradScaler()
    if distributed:
        world_size = int(os.environ["WORLD_SIZE"])
    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext

    train_prep = []
    train_loss = []
    train_acc = []
    val_prep = []
    val_loss = []
    val_acc = []
    epoch_times = []
    checkpoint_times = []
    results = {}
    best_val_loss = float("inf")
    best_val_acc = 0.0
    for epoch in range(train_config.num_epochs):
        epoch_start_time = time.perf_counter()
        with MemoryTrace() as memtrace:  # track the memory usage
            model.train()
            total_loss = 0.0
            total_acc = 0.0
            num_batches = 0  # batches actually processed; the epoch may stop early
            num_batches_per_epoch = len(train_dataloader)
            total_length = num_batches_per_epoch // gradient_accumulation_steps
            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
            for step, batch in enumerate(train_dataloader):
                _move_batch_to_device(batch, device)
                with autocast():
                    outputs, *rest = model(**batch)
                # Optional extra outputs, in order: text accuracy, per-layer audio
                # accuracies (seven layers), per-layer losses (seven audio + one text).
                acc = rest[0] if len(rest) > 0 else -1
                audio_acc = rest[1] if len(rest) > 1 else []
                layer_loss = rest[2] if len(rest) > 2 else []
                loss = outputs.loss
                # wandb drops steps that go backwards, so count raw batches across epochs.
                log_step = epoch * num_batches_per_epoch + step

                if log_config.use_wandb and step % log_config.log_interval == 0 and is_main:
                    _log_train_inner_metrics(loss, acc, audio_acc, layer_loss, log_step)

                # Accumulate unscaled values so epoch averages are true per-batch means
                # independent of gradient_accumulation_steps.
                total_loss += loss.detach().float()
                total_acc += acc
                num_batches += 1

                # Only the back-propagated loss is scaled, so the accumulated gradient is
                # the mean over the accumulation window.
                scaled_loss = loss / gradient_accumulation_steps
                if train_config.use_fp16:
                    # if fp16 is enabled, use gradient scaler to handle gradient update
                    scaler.scale(scaled_loss).backward()
                else:
                    scaled_loss.backward()

                if (step + 1) % gradient_accumulation_steps == 0 or step == num_batches_per_epoch - 1:
                    if train_config.use_fp16:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    if lr_scheduler is not None:
                        lr_scheduler.step()
                        current_lr = lr_scheduler.get_last_lr()[0]
                    else:
                        current_lr = optimizer.param_groups[0]["lr"]
                    if current_lr == 0:
                        # The schedule is exhausted; end this epoch early.
                        break
                    if log_config.use_wandb and step % log_config.log_interval == 0 and is_main:
                        wandb.log({"train_inner/lr": current_lr}, step=log_step)
                    optimizer.zero_grad()
                    pbar.update(1)

                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{num_batches_per_epoch} completed (loss: {loss.detach().float()}, acc: {acc})")

                if (epoch * total_length + step + 1) % train_config.validation_interval == 0 and train_config.run_validation:
                    eval_ppl, eval_epoch_loss, *eval_extra = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
                    eval_epoch_acc = eval_extra[0] if eval_extra else -1
                    checkpoint_start_time = time.perf_counter()
                    if train_config.save_model and (eval_epoch_loss < best_val_loss):
                        checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch+1)}_step_{step+1}"
                        _save_best_checkpoint(model, optimizer, rank, train_config, fsdp_config, epoch, checkpoint_name, is_main)
                    checkpoint_end_time = time.perf_counter() - checkpoint_start_time
                    checkpoint_times.append(checkpoint_end_time)

                    if eval_epoch_loss < best_val_loss:
                        best_val_loss = eval_epoch_loss
                        if is_main:
                            logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                    val_loss.append(eval_epoch_loss)
                    val_prep.append(eval_ppl)
                    if eval_extra and eval_epoch_acc > best_val_acc:
                        best_val_acc = eval_epoch_acc
                        if is_main:
                            logger.info(f"best eval acc on epoch {epoch+1} is {best_val_acc}")
                    val_acc.append(eval_epoch_acc)

                    if log_config.use_wandb and is_main:
                        wandb.log({"valid/val_epoch_loss": eval_epoch_loss, "valid/val_perplexity": eval_ppl, "valid/best_val_loss": best_val_loss, "valid/val_accuracy": val_acc[-1], "valid/val_best_accuracy": best_val_acc})

                    if train_config.run_test_during_validation:
                        if is_main:
                            logger.info("=====================================")
                            logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:")
                            with autocast():
                                logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt))
                            logger.info("=====================================")
                        if distributed:
                            dist.barrier()

                    # evaluation() put the model in eval mode; resume training mode for
                    # the remaining batches of this epoch.
                    model.train()
            pbar.close()

        epoch_end_time = time.perf_counter() - epoch_start_time
        epoch_times.append(epoch_end_time)
        if distributed:
            # Always sum across ranks (also correct with one visible GPU per process).
            # all_reduce needs tensors, but acc is a plain number when the model
            # reports no accuracy.
            if torch.is_tensor(total_loss) and not torch.is_tensor(total_acc):
                total_acc = torch.tensor(float(total_acc), device=total_loss.device)
            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(total_acc, op=dist.ReduceOp.SUM)
        train_epoch_loss = total_loss / num_batches
        train_epoch_acc = total_acc / num_batches
        if distributed:
            train_epoch_loss = train_epoch_loss / world_size
            train_epoch_acc = train_epoch_acc / world_size
        train_perplexity = torch.exp(train_epoch_loss)

        train_prep.append(train_perplexity)
        train_loss.append(train_epoch_loss)
        train_acc.append(train_epoch_acc)

        if log_config.use_wandb and is_main:
            wandb.log({"train/train_perplexity": train_perplexity, "train/train_epoch_loss": train_epoch_loss, "train/train_epoch_acc": train_epoch_acc})

        if is_main:
            logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
            logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB")
            logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
            logger.info(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
            logger.info(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
            logger.info(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")

    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
    avg_train_prep = sum(train_prep) / len(train_prep)
    avg_train_loss = sum(train_loss) / len(train_loss)
    avg_train_acc = sum(train_acc) / len(train_acc)

    results['avg_train_prep'] = avg_train_prep
    results['avg_train_loss'] = avg_train_loss
    results['avg_train_acc'] = avg_train_acc
    # Validation may never have fired (e.g. validation_interval longer than the run);
    # skip the eval averages instead of dividing by zero.
    if train_config.run_validation and val_loss:
        results['avg_eval_prep'] = sum(val_prep) / len(val_prep)
        results['avg_eval_loss'] = sum(val_loss) / len(val_loss)
        results['avg_eval_acc'] = sum(val_acc) / len(val_acc)
    results["avg_epoch_time"] = avg_epoch_time
    results["avg_checkpoint_time"] = avg_checkpoint_time

    return results
def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
    """
    Evaluate the model on the given dataloader and return averaged metrics.

    Args:
        model: model whose forward returns ``(outputs, *extras)``; ``outputs.loss`` is the
            loss and the first extra, if present, is the accuracy.
        train_config: reads ``enable_fsdp``, ``enable_ddp`` and ``use_fp16``.
        eval_dataloader: yields dict batches; tensors (also inside nested dicts) are moved
            to ``local_rank`` in distributed runs, otherwise to ``cuda:0``.
        local_rank: device index in distributed runs; only local rank 0 logs the result.
        tokenizer: kept for backward compatibility; predictions are not decoded.

    Returns:
        ``(eval_ppl, eval_epoch_loss, eval_epoch_acc)``. In distributed runs the values
        are averaged over all ranks. The accuracy averages ``-1`` per batch when the
        model reports none.

    Note:
        Leaves the model in eval mode; callers that continue training must call
        ``model.train()``.
    """
    distributed = train_config.enable_fsdp or train_config.enable_ddp
    if distributed:
        world_size = int(os.environ["WORLD_SIZE"])
    device = local_rank if distributed else 'cuda:0'
    model.eval()
    eval_loss = 0.0  # Initialize evaluation loss
    eval_acc = 0.0
    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext  # (Fix:MZY): fix expected scalar type mismatch in norm
    with MemoryTrace() as memtrace:
        total_length = len(eval_dataloader)
        pbar = tqdm(colour="green", desc=f"Evaluating Epoch", total=total_length, dynamic_ncols=True)
        for step, batch in enumerate(eval_dataloader):
            # Same device placement as in train(), including nested dicts of tensors.
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.to(device)
                elif isinstance(value, dict):
                    for sub_key, sub_value in value.items():
                        if isinstance(sub_value, torch.Tensor):
                            value[sub_key] = sub_value.to(device)
            # Ensure no gradients are computed for this scope to save memory
            with torch.no_grad():
                # Forward pass and compute loss
                with autocast():  # (Fix:MZY): fix expected scalar type mismatch in norm
                    outputs, *rest = model(**batch)
            acc = rest[0] if rest else -1
            loss = outputs.loss
            eval_loss += loss.detach().float()
            eval_acc += acc
            pbar.update(1)
            pbar.set_description(f"step: {step+1}/{total_length}, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}")
        pbar.close()

    if distributed:
        # Always sum across ranks: the result is divided by world_size below, so skipping
        # the reduce (e.g. one visible GPU per process) would under-report the metrics.
        # all_reduce needs tensors, but acc is a plain number when the model reports none.
        if torch.is_tensor(eval_loss) and not torch.is_tensor(eval_acc):
            eval_acc = torch.tensor(float(eval_acc), device=eval_loss.device)
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
        dist.all_reduce(eval_acc, op=dist.ReduceOp.SUM)

    # Compute average loss and perplexity
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_epoch_acc = eval_acc / len(eval_dataloader)
    if distributed:
        eval_epoch_loss = eval_epoch_loss / world_size
        eval_epoch_acc = eval_epoch_acc / world_size
    eval_ppl = torch.exp(eval_epoch_loss)

    # Print evaluation metrics
    if not distributed or local_rank == 0:
        logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}")

    return eval_ppl, eval_epoch_loss, eval_epoch_acc
def freeze_transformer_layers(model, num_layer):
    """Disable gradients for every parameter of the first ``num_layer`` entries of ``model.model.layers``."""
    leading_layers = (blk for idx, blk in enumerate(model.model.layers) if idx < num_layer)
    for blk in leading_layers:
        for weight in blk.parameters():
            weight.requires_grad = False
def check_frozen_layers_peft_model(model):
    """Log ``requires_grad`` for each parameter of each layer in ``model.base_model.model.model.layers``."""
    stacked_layers = model.base_model.model.model.layers
    for layer_idx, layer in enumerate(stacked_layers):
        for param_name, param in layer.named_parameters():
            logger.info(f"Layer {layer_idx}, parameter {param_name}: requires_grad = {param.requires_grad}")
def setup():
    """Create the default torch.distributed process group using the NCCL backend."""
    dist.init_process_group(backend="nccl")
def setup_environ_flags(rank):
    """
    Set environment flags that make distributed failures easier to debug.

    Sets ``TORCH_SHOW_CPP_STACKTRACES=1`` and ``NCCL_ASYNC_ERROR_HANDLING=1``,
    overwriting any existing values.

    Args:
        rank: process rank; only rank 0 logs what was set.
    """
    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
    # Optional extras, intentionally left disabled:
    #   TORCH_DISTRIBUTED_DEBUG=DETAIL gives verbose collective diagnostics;
    #   PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' can reduce
    #   fragmentation-related CUDA OOMs.
    if rank == 0:
        # The previous message claimed "torch dist debug set to detail", but
        # TORCH_DISTRIBUTED_DEBUG is not set here; report what actually is.
        logger.info("--> Set TORCH_SHOW_CPP_STACKTRACES=1 and NCCL_ASYNC_ERROR_HANDLING=1")
def cleanup():
    """Tear down the default torch.distributed process group once training is done."""
    dist.destroy_process_group(None)
def clear_gpu_cache(rank=None):
    """Release cached CUDA memory via ``torch.cuda.empty_cache()``; rank 0 announces it."""
    announce = rank == 0
    if announce:
        logger.info(f"Clearing GPU cache for all ranks")
    torch.cuda.empty_cache()
def get_parameter_dtypes(model):
    """Return a ``{parameter_name: dtype}`` dict for every named parameter of ``model``."""
    return {param_name: tensor.dtype for param_name, tensor in model.named_parameters()}
def print_model_size(model, config, rank: int = 0) -> None:
    """
    Log the model name and its number of trainable parameters (in millions).

    Args:
        model: The PyTorch model; only parameters with ``requires_grad`` are counted.
        config: Config object; only ``config.model_name`` is read.
        rank (int, optional): Current process's rank; nothing is logged unless it is 0.
            Defaults to 0.
    """
    if rank != 0:
        return
    logger.info(f"--> Model {config.model_name}")
    trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"--> {config.model_name} has {trainable_count / 1e6} Million params\n")
def print_module_size(module, module_name, rank: int = 0) -> None:
    """
    Log a module's name and its number of trainable parameters (in millions).

    Args:
        module: The PyTorch module; only parameters with ``requires_grad`` are counted.
        module_name (str): Name used in the log lines.
        rank (int, optional): Current process's rank; nothing is logged unless it is 0.
            Defaults to 0.
    """
    if rank != 0:
        return
    logger.info(f"--> Module {module_name}")
    trainable_count = sum(p.numel() for p in module.parameters() if p.requires_grad)
    logger.info(f"--> {module_name} has {trainable_count / 1e6} Million params\n")
def get_policies(cfg, rank):
    """
    Get the FSDP mixed-precision and wrapping policies.

    bf16 counts as available only when all of these hold: torch was built with CUDA,
    ``torch.cuda.is_bf16_supported()``, CUDA >= 11.0, NCCL is available and
    NCCL >= 2.10.

    Args:
        cfg: config providing ``mixed_precision`` and ``use_fp16``.
        rank: process rank; only rank 0 logs.

    Returns:
        ``(mixed_precision_policy, wrapping_policy)``. The first is ``bfSixteen_mixed``,
        ``fpSixteen`` or ``None`` (FP32). The second comes from ``get_llama_wrapper()``.
    """
    cuda_version = torch.version.cuda
    # Compare "major.minor" with plain tuples instead of pkg_resources.packaging
    # (pkg_resources is deprecated). The checks short-circuit in order, so a CPU-only
    # build never queries CUDA or NCCL.
    verify_bfloat_support = (
        cuda_version
        and torch.cuda.is_bf16_supported()
        and tuple(int(part) for part in cuda_version.split(".")[:2]) >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )

    mixed_precision_policy = None
    wrapping_policy = None

    # Mixed precision
    if cfg.mixed_precision:
        bf16_ready = verify_bfloat_support
        if bf16_ready and not cfg.use_fp16:
            mixed_precision_policy = bfSixteen_mixed
            if rank == 0:
                logger.info("bFloat16 enabled for mixed precision - using bfSixteen policy")
        elif cfg.use_fp16:
            mixed_precision_policy = fpSixteen
            if rank == 0:
                logger.info("FP16 enabled")
        elif rank == 0:
            # Gated like the other branches so non-zero ranks don't repeat the message.
            logger.info("bFloat16 support not present. Using FP32, and not mixed precision")
    wrapping_policy = get_llama_wrapper()
    return mixed_precision_policy, wrapping_policy
def save_train_params(train_config, fsdp_config, rank):
    """
    Save ``train_config`` and ``fsdp_config`` into a ``train_params.yaml``.

    The file is written to
    ``<cwd>/<dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>/``.
    It is used by the converter script in the inference folder to fetch the HF model
    name or path, and is also helpful as a log for future reference.

    Args:
        train_config: training config object. Its public attributes are dumped, and
            ``dist_checkpoint_root_folder``, ``dist_checkpoint_folder`` and
            ``model_name`` build the output folder.
        fsdp_config: FSDP config object. Its public attributes are dumped too and take
            precedence on key collisions with ``train_config``.
        rank: process rank; only rank 0 logs the success message.
    """
    # Convert the train_config and fsdp_config objects to dictionaries,
    # converting all values to strings to ensure they can be serialized into a YAML file
    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
    # Merge the two dictionaries into one
    train_params_dict = {**train_config_dict, **fsdp_config_dict}
    # Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
    folder_name = (
        train_config.dist_checkpoint_root_folder
        + "/"
        + train_config.dist_checkpoint_folder
        + "-"
        + train_config.model_name
    )

    save_dir = Path.cwd() / folder_name
    # exist_ok avoids the check-then-create race (FileExistsError) when several
    # ranks call this at the same time.
    os.makedirs(save_dir, exist_ok=True)

    # Convert the dictionary to a YAML string
    config_yaml = yaml.dump(train_params_dict, indent=4)
    file_name = os.path.join(save_dir, 'train_params.yaml')

    # Check if there's a directory with the same name as the file
    if os.path.isdir(file_name):
        logger.info(f"Error: {file_name} is a directory, not a file.")
    else:
        # Write the YAML string to the file
        with open(file_name, 'w', encoding="utf-8") as f:
            f.write(config_yaml)
        if rank == 0:
            logger.info(f"training params are saved in {file_name}")