import time
import os
from tqdm import tqdm
import sys
from copy import deepcopy
from contextlib import suppress

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.api import FullOptimStateDictConfig
from einops import rearrange


class Dict2Class:
    # Wrap a plain dict so its keys can be accessed as attributes (args.key style).
    def __init__(self, data_dict):
        for key, value in data_dict.items():
            setattr(self, key, value)


class SysLogger(object):
    # Tee stdout to a log file: anything written goes to both the terminal and the file.
    def __init__(self, filename="../log/log.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        # Write the same message to both streams (no extra newline on the terminal).
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Required when this object replaces sys.stdout and callers invoke sys.stdout.flush().
        self.terminal.flush()
        self.log.flush()


def get_cast_dtype(precision: str):
    # Map a precision string to the dtype inputs should be cast to (None keeps fp32).
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_mp_policy_dtype(precision: str):
    # Dtype used for the FSDP mixed-precision policy.
    if "bfloat16" in precision or "bf16" in precision:
        return torch.bfloat16
    elif precision == "fp16":
        return torch.float16
    else:
        return torch.float32


def get_autocast(precision, cache_enabled=True):
    # Return a zero-argument context-manager factory so callers can uniformly do `with autocast():`.
    if precision == "amp":
        return lambda: torch.cuda.amp.autocast(cache_enabled=cache_enabled)
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        return lambda: torch.cuda.amp.autocast(
            dtype=torch.bfloat16, cache_enabled=cache_enabled
        )
    else:
        return suppress


def train_one_epoch(
    args, model, epoch, trainloader, tokenizer, optimizer, lr_scheduler, device_id, tb
):
    # setup loaders
    num_batches_per_epoch = len(trainloader)
    total_training_steps = num_batches_per_epoch * args.num_epochs
    print('num_batches_per_epoch={}, total_training_steps={}'.format(num_batches_per_epoch, total_training_steps))

    autocast = get_autocast(
        args.precision, cache_enabled=(not args.fsdp)
    )  # if fsdp, disable cache to save memory
    cast_dtype = get_cast_dtype(args.precision)

    # setup model
    media_token_id = tokenizer("