import itertools
import logging
import math
import time
from contextlib import nullcontext

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

try:
    from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss
    from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
    batched_load_balancing_loss = None
    clear_load_balancing_loss = None
    MoEArgs = None

try:
    import wandb
except ImportError:
    wandb = None

from open_lm.data import sample_chunk
from open_lm.distributed import is_master
from open_lm.precision import get_autocast
from open_lm.meters import AverageMeter


def unwrap_model(model):
    if hasattr(model, "module"):
        return model.module
    else:
        return model


def backward(total_loss, scaler):
    if scaler is not None:
        scaler.scale(total_loss).backward()
    else:
        total_loss.backward()


def train_one_epoch(
    model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, averagers=None
):
    """Trains model for one epoch on the provided data.

    Returns:
        success (bool): Whether training completed successfully.
        step (int): Global step at the end of the epoch.

    Note that an "epoch" is not one full pass through the data, but rather the number of tokens
    specified by `--train-num-samples`, rounded based on shard size. As such, the number of steps
    in an "epoch" can vary, and we have to keep track of steps separately.
    """
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)

    model.train()

    data["train"].set_epoch(epoch)  # set epoch in process-safe manner via sampler or shared_epoch
    dataloader = data["train"].dataloader
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    losses_m = AverageMeter()
    load_balancing_losses_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    forward_time_m = AverageMeter()
    backward_time_m = AverageMeter()
    optim_step_time_m = AverageMeter()
    sync_time_m = AverageMeter()

    if averagers is not None and args.log_avg_model_training_loss:
        losses_avg_m = {key: AverageMeter() for key in averagers.avgs_dict.keys()}
        local_avg_losses = {}
        total_loss_avg = {}

    # used only if --log-logit-mean flag is passed
    logit_m = AverageMeter()

    end = time.time()

    data_iterator = iter(dataloader)

    if args.moe_freq > 0:
        # these MoEArgs are necessary for logging load balancing.
        moe_args = MoEArgs(
            hidden_size=model.dim,
            ffn_hidden_size=model.dim * 4,
            moe_num_experts=args.moe_num_experts,
            num_layers=model.n_layers // args.moe_freq,
            moe_expert_model_parallelism=True,
            moe_top_k=args.moe_top_k,
            device=torch.cuda.current_device(),
            moe_capacity_factor=args.moe_capacity_factor,
            moe_loss_weight=args.moe_loss_weight,
            fp16=False,
            bf16=False,
        )

    for i in itertools.count():
        if not args.skip_scheduler:
            scheduler(step)

        if step >= total_steps:
            logging.warning(f"step: {step} has reached/exceeded total_steps: {total_steps}. Ending training.")
            break
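
        # Fetch the next batch. Each rank reports whether it still has data; the SUM all-reduce
        # below lets every rank see when any rank is exhausted, so all ranks exit the loop
        # together instead of hanging in a later collective.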
ending training.") break try: batch = next(data_iterator) has_data = torch.tensor(1, dtype=torch.long, device=device) except StopIteration: has_data = torch.tensor(0, dtype=torch.long, device=device) if args.world_size > 1: dist.all_reduce(has_data, op=ReduceOp.SUM) # if is_master(args): # print("current has data", has_data) if has_data < args.world_size: break # (texts,) = batch # texts = torch.LongTensor(texts).to(device) data_time_m.update(time.time() - end) optimizer.zero_grad() if args.accum_freq == 1: with autocast(): forward_start = time.time() if args.dataset_type == "jsonl": inputs, targets = batch # for input in inputs: # max_label_length = max(len(l) for l in input) # mod_inputs = [] # mod_targets = [] # for input, target in zip(inputs, targets): # assert len(input) == len(target) # mod_inputs.append(input + [1] * (max_label_length - len(input))) # mod_targets.append(target + [-100] * (max_label_length - len(target))) inputs = torch.LongTensor(inputs).to(device) targets = torch.LongTensor(targets).to(device) inputs = inputs[:, :-1] targets = targets[:, 1:] assert inputs.size() == targets.size() if is_master(args): if i == 0: print("enter customed jsonl step") print("inputs id of first forward on") print("current inputs") print(inputs[:3, :]) print("current targets") print(targets[:3, :]) else: (texts,) = batch if is_master(args): pass texts = torch.LongTensor(texts).to(device) inputs, targets = sample_chunk(texts, args) out, _, _ = model(inputs) if is_master(args) and i == 0: pass forward_time_m.update(time.time() - forward_start) if args.log_logit_mean: logit_m.update(torch.mean(out).item()) total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1)) total_loss = total_lm_loss if args.moe_freq > 0: total_load_balancing_loss = batched_load_balancing_loss(moe_args) clear_load_balancing_loss() total_loss += total_load_balancing_loss backward_start = time.time() backward(total_loss, scaler) backward_time_m.update(time.time() - backward_start) if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0: with autocast(): for key, averager in averagers.avgs_dict.items(): with torch.no_grad(): out_avg, _, _ = averager.av_model(inputs) # save the loss for the average model for logging total_loss_avg[key] = loss(out_avg.reshape(-1, args.vocab_size), targets.reshape(-1)) else: # split up batch into accum_freq chunks -- if you have --batch-size 8 and --accum-freq 4 # then you only process 2 items at a time. batch-size must be divisible by accume-freq. assert args.per_gpu_batch_size % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq" per_batch = args.per_gpu_batch_size // args.accum_freq # inputs, targets = sample_chunk(texts, args) inputs, targets = batch forward_total_time = 0 backward_total_time = 0 for ii in range(args.accum_freq): maybe_no_sync = nullcontext # Don't sync gradients until the final batch for FSDP. 
                if isinstance(model, FSDP) and ii != args.accum_freq - 1:
                    maybe_no_sync = model.no_sync

                with maybe_no_sync():
                    with autocast():
                        forward_start = time.time()
                        inputs_ii = inputs[ii * per_batch : (ii + 1) * per_batch]
                        if inputs_ii.shape[0] == 0:
                            break
                        targets_ii = targets[ii * per_batch : (ii + 1) * per_batch]
                        out, _, _ = model(inputs_ii)
                        forward_total_time += time.time() - forward_start

                        if args.log_logit_mean:
                            logit_m.update(torch.mean(out).item())

                        local_lm_loss = (
                            loss(out.reshape(-1, args.vocab_size), targets_ii.reshape(-1))
                            * inputs_ii.shape[0]
                            / inputs.shape[0]
                        )
                        local_loss = local_lm_loss
                        if args.moe_freq > 0:
                            local_load_balancing_loss = batched_load_balancing_loss(moe_args)
                            clear_load_balancing_loss()
                            local_loss += local_load_balancing_loss

                    backward_start = time.time()
                    backward(local_loss, scaler)
                    backward_total_time += time.time() - backward_start

                    with autocast():
                        if (
                            averagers is not None
                            and args.log_avg_model_training_loss
                            and i % args.log_avg_model_training_loss == 0
                        ):
                            for key, averager in averagers.avgs_dict.items():
                                with torch.no_grad():
                                    out_avg, _, _ = averager.av_model(inputs_ii)
                                    local_avg_losses[key] = (
                                        loss(out_avg.reshape(-1, args.vocab_size), targets_ii.reshape(-1))
                                        * inputs_ii.shape[0]
                                        / inputs.shape[0]
                                    )

                if ii == 0:
                    total_lm_loss = local_lm_loss
                    if args.moe_freq > 0:
                        total_load_balancing_loss = local_load_balancing_loss
                    if (
                        averagers is not None
                        and args.log_avg_model_training_loss
                        and i % args.log_avg_model_training_loss == 0
                    ):
                        for key, averager in averagers.avgs_dict.items():
                            total_loss_avg[key] = local_avg_losses[key]
                else:
                    total_lm_loss += local_lm_loss
                    if args.moe_freq > 0:
                        total_load_balancing_loss += local_load_balancing_loss
                    if (
                        averagers is not None
                        and args.log_avg_model_training_loss
                        and i % args.log_avg_model_training_loss == 0
                    ):
                        for key, averager in averagers.avgs_dict.items():
                            total_loss_avg[key] += local_avg_losses[key]

            forward_time_m.update(forward_total_time)
            backward_time_m.update(backward_total_time)

            total_loss = total_lm_loss
            if args.moe_freq > 0:
                total_loss += total_load_balancing_loss

        optim_step_start = time.time()
        if scaler is not None:
            if args.grad_clip_norm is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            if args.grad_clip_norm is not None:
                if isinstance(model, FSDP):
                    model.clip_grad_norm_(args.grad_clip_norm, norm_type=2.0)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            optimizer.step()
        optim_step_time_m.update(time.time() - optim_step_start)

        if averagers is not None:
            averagers.step()

        global_loss_tensor = total_loss.detach().clone()
        if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
            # same for the average-model losses
            for key, value in total_loss_avg.items():
                total_loss_avg[key] = value.detach().clone()

        sync_start = time.time()
        if args.world_size > 1:
            dist.all_reduce(global_loss_tensor, op=ReduceOp.AVG)
            if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
                for key, value in total_loss_avg.items():
                    dist.all_reduce(value, op=ReduceOp.AVG)
            if args.moe_freq > 0:
                dist.all_reduce(total_load_balancing_loss, op=ReduceOp.AVG)
        sync_time_m.update(time.time() - sync_start)

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1
        step += 1

        if is_master(args):
            batch_size = len(inputs)
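            # With MoE enabled, track the language-modeling loss and the load-balancing loss in
            # separate meters; otherwise the reduced global loss is the LM loss alone.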
            if args.moe_freq > 0:
                losses_m.update(global_loss_tensor.item() - total_load_balancing_loss.item(), batch_size)
                load_balancing_losses_m.update(total_load_balancing_loss.item(), batch_size)
            else:
                losses_m.update(global_loss_tensor.item(), batch_size)
            if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
                for key, value in total_loss_avg.items():
                    losses_avg_m[key].update(value.item(), batch_size)

            if i % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch or step == total_steps - 1:
                num_samples = batch_count * batch_size * args.world_size
                samples_per_epoch = dataloader.num_samples
                percent_complete = 100.0 * batch_count / num_batches_per_epoch

                # gathered_loss = [torch.zeros_like(total_loss) for _ in range(args.world_size)]
                # torch.distributed.all_gather(gathered_loss, total_loss)
                # losses_m.update(sum(gathered_loss).item() / args.world_size, batch_size * args.world_size)

                samples_per_second = inputs.numel() * args.world_size / batch_time_m.val
                samples_per_second_per_gpu = inputs.numel() / batch_time_m.val
                loss_str = f"Loss: {losses_m.avg:.3f}"
                loss_str += f" LB-Loss: {load_balancing_losses_m.avg:.3f}" if args.moe_freq > 0 else ""
                logging.info(
                    f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                    f"{loss_str} "
                    f"Data (t): {data_time_m.avg:.3f} "
                    f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
                    f"LR: {optimizer.param_groups[0]['lr']:5f} "
                )

                # Save train loss / etc. Using non-averaged meter values, as loggers have their own smoothing.
                log_data = {
                    "loss": losses_m.val,
                    "load_balancing_loss": load_balancing_losses_m.val,
                    "data_time": data_time_m.val,
                    "batch_time": batch_time_m.val,
                    "forward_time": forward_time_m.val,
                    "backward_time": backward_time_m.val,
                    "optim_step_time": optim_step_time_m.val,
                    "sync_time": sync_time_m.val,
                    "samples_per_second": samples_per_second,
                    "samples_per_second_per_gpu": samples_per_second_per_gpu,
                    "lr": optimizer.param_groups[0]["lr"],
                    "tokens": (step + 1) * args.global_batch_size * args.seq_len,
                    "expected_steps_epoch": data["train"].dataloader.num_batches,
                    "seen_steps_epoch": batch_count,
                }
                if averagers is not None and args.log_avg_model_training_loss:
                    for k in averagers.avgs_dict.keys():
                        if i % args.log_avg_model_training_loss == 0 or batch_count == num_batches_per_epoch:
                            log_data[k + "_loss"] = losses_avg_m[k].avg
                if args.log_logit_mean:
                    log_data["logit_mean"] = logit_m.val

                for name, val in log_data.items():
                    name = "train/" + name
                    if tb_writer is not None:
                        tb_writer.add_scalar(name, val, step)
                    if args.wandb:
                        assert wandb is not None, "Please install wandb."
                        wandb.log({name: val, "step": step, "tokens": log_data["tokens"]})

                # resetting batch / data time meters per log window
                batch_time_m.reset()
                data_time_m.reset()
                forward_time_m.reset()
                backward_time_m.reset()
                optim_step_time_m.reset()
                sync_time_m.reset()

                if math.isnan(losses_m.val):
                    # Loss occasionally goes to NaN; we see this sometimes with bad nodes. In that
                    # case we want to free resources and prevent other issues, e.g., saving
                    # checkpoints and optimization states that may lead to skipped training on
                    # restarts.
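                    # Returning False tells the caller that the epoch did not complete successfully,
                    # so it can tear down without persisting state from the diverged run.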
                    return False, step

                # reset all average meters
                losses_m.reset()
                if averagers is not None and args.log_avg_model_training_loss:
                    for k in averagers.avgs_dict.keys():
                        losses_avg_m[k].reset()

    # end for

    if tb_writer is not None:
        tb_writer.flush()
    return True, step
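

if __name__ == "__main__":
    # Illustrative sketch only (not part of the open_lm training path): it checks the loss
    # scaling used in the gradient-accumulation branch above, where each micro-batch loss is
    # multiplied by `chunk_size / full_batch_size`. With equal sequence lengths, the scaled
    # chunk losses sum to the full-batch mean loss, which is why accumulating per-chunk
    # backward passes matches a single full-batch backward pass. Shapes are arbitrary toy values.
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab_size, seq_len, full_batch, accum_freq = 11, 7, 8, 4
    per_chunk = full_batch // accum_freq

    logits = torch.randn(full_batch, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (full_batch, seq_len))

    # mean cross-entropy over all tokens in the full batch
    full_loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))

    # accumulate per-chunk mean losses, scaled by chunk_size / full_batch_size
    accum_loss = torch.tensor(0.0)
    for ii in range(accum_freq):
        sl = slice(ii * per_chunk, (ii + 1) * per_chunk)
        chunk_loss = F.cross_entropy(logits[sl].reshape(-1, vocab_size), labels[sl].reshape(-1))
        accum_loss = accum_loss + chunk_loss * per_chunk / full_batch

    assert torch.allclose(full_loss, accum_loss, atol=1e-6)
    print(f"full-batch loss {full_loss.item():.6f} == accumulated chunk loss {accum_loss.item():.6f}")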