import itertools
import logging
import math
import time
from contextlib import nullcontext

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
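
# megablocks and wandb are optional dependencies; fall back to None when they are not installed.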
try:
    from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss
    from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
    batched_load_balancing_loss = None
    clear_load_balancing_loss = None
    MoEArgs = None

try:
    import wandb
except ImportError:
    wandb = None

from open_lm.data import sample_chunk
from open_lm.distributed import is_master
from open_lm.precision import get_autocast
from open_lm.meters import AverageMeter


def unwrap_model(model):
    if hasattr(model, "module"):
        return model.module
    else:
        return model


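# Backward-pass helper: scale the loss first when a GradScaler is in use (mixed-precision training).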
def backward(total_loss, scaler):
    if scaler is not None:
        scaler.scale(total_loss).backward()
    else:
        total_loss.backward()

def train_one_epoch(
    model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, averagers=None
):
    """Trains the model for one epoch on the provided data.

    Returns:
        success (bool): Whether training completed successfully.
        step (int): Global step at the end of the epoch. Note that an "epoch" is not one full pass through the
            data, but rather the number of tokens specified by `--train-num-samples`, rounded based on shard size.
            As such, the number of steps in an "epoch" can vary, and we have to keep track of steps separately.
    """
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)

    model.train()

    data["train"].set_epoch(epoch)
    dataloader = data["train"].dataloader
    num_batches_per_epoch = dataloader.num_batches

    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
    losses_m = AverageMeter()
    load_balancing_losses_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    forward_time_m = AverageMeter()
    backward_time_m = AverageMeter()
    optim_step_time_m = AverageMeter()
    sync_time_m = AverageMeter()
    if averagers is not None and args.log_avg_model_training_loss:
        losses_avg_m = {key: AverageMeter() for key in averagers.avgs_dict.keys()}
        local_avg_losses = {}
        total_loss_avg = {}

    logit_m = AverageMeter()

    end = time.time()

    data_iterator = iter(dataloader)

    if args.moe_freq > 0:
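        # These MoEArgs mirror the model configuration; megablocks uses them to compute the load balancing loss.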
        moe_args = MoEArgs(
            hidden_size=model.dim,
            ffn_hidden_size=model.dim * 4,
            moe_num_experts=args.moe_num_experts,
            num_layers=model.n_layers // args.moe_freq,
            moe_expert_model_parallelism=True,
            moe_top_k=args.moe_top_k,
            device=torch.cuda.current_device(),
            moe_capacity_factor=args.moe_capacity_factor,
            moe_loss_weight=args.moe_loss_weight,
            fp16=False,
            bf16=False,
        )

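    # Main training loop: run until some rank exhausts its data or the global step budget is reached.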
    for i in itertools.count():
        if not args.skip_scheduler:
            scheduler(step)

        if step >= total_steps:
            logging.warning(f"step: {step} has reached/exceeded total_steps: {total_steps}. ending training.")
            break

        try:
            batch = next(data_iterator)
            has_data = torch.tensor(1, dtype=torch.long, device=device)
        except StopIteration:
            has_data = torch.tensor(0, dtype=torch.long, device=device)
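
        # If any rank has run out of data, all ranks stop together so that collective ops stay in sync.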
        if args.world_size > 1:
            dist.all_reduce(has_data, op=ReduceOp.SUM)

        if has_data < args.world_size:
            break

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()

        if args.accum_freq == 1:
            with autocast():
                forward_start = time.time()
                if args.dataset_type == "jsonl":
                    inputs, targets = batch
                    inputs = torch.LongTensor(inputs).to(device)
                    targets = torch.LongTensor(targets).to(device)
                    # Shift by one token for next-token prediction.
                    inputs = inputs[:, :-1]
                    targets = targets[:, 1:]
                    assert inputs.size() == targets.size()
                    if is_master(args) and i == 0:
                        print("entering custom jsonl step")
                        print("current inputs")
                        print(inputs[:3, :])
                        print("current targets")
                        print(targets[:3, :])
                else:
                    (texts,) = batch
                    texts = torch.LongTensor(texts).to(device)
                    inputs, targets = sample_chunk(texts, args)

                out, _, _ = model(inputs)
                forward_time_m.update(time.time() - forward_start)

                if args.log_logit_mean:
                    logit_m.update(torch.mean(out).item())

                total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1))
                total_loss = total_lm_loss
                if args.moe_freq > 0:
                    total_load_balancing_loss = batched_load_balancing_loss(moe_args)
                    clear_load_balancing_loss()
                    total_loss += total_load_balancing_loss

            backward_start = time.time()
            backward(total_loss, scaler)
            backward_time_m.update(time.time() - backward_start)
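
            # Optionally evaluate the averaged model(s) on this batch for logging.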
            if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
                with autocast():
                    for key, averager in averagers.avgs_dict.items():
                        with torch.no_grad():
                            out_avg, _, _ = averager.av_model(inputs)
                            total_loss_avg[key] = loss(out_avg.reshape(-1, args.vocab_size), targets.reshape(-1))
        else:
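            # Gradient accumulation: split the per-GPU batch into accum_freq microbatches and scale each
            # microbatch loss by its share of the full batch before backpropagating.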
            assert args.per_gpu_batch_size % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
            per_batch = args.per_gpu_batch_size // args.accum_freq

            inputs, targets = batch

            forward_total_time = 0
            backward_total_time = 0
            for ii in range(args.accum_freq):
                maybe_no_sync = nullcontext
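                # Skip gradient synchronization for all but the last microbatch when using FSDP.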
                if isinstance(model, FSDP) and ii != args.accum_freq - 1:
                    maybe_no_sync = model.no_sync
                with maybe_no_sync():
                    with autocast():
                        forward_start = time.time()
                        inputs_ii = inputs[ii * per_batch : (ii + 1) * per_batch]
                        if inputs_ii.shape[0] == 0:
                            break
                        targets_ii = targets[ii * per_batch : (ii + 1) * per_batch]
                        out, _, _ = model(inputs_ii)
                        forward_total_time += time.time() - forward_start

                        if args.log_logit_mean:
                            logit_m.update(torch.mean(out).item())

                        local_lm_loss = (
                            loss(out.reshape(-1, args.vocab_size), targets_ii.reshape(-1))
                            * inputs_ii.shape[0]
                            / inputs.shape[0]
                        )
                        local_loss = local_lm_loss
                        if args.moe_freq > 0:
                            local_load_balancing_loss = batched_load_balancing_loss(moe_args)
                            clear_load_balancing_loss()
                            local_loss += local_load_balancing_loss

                    backward_start = time.time()
                    backward(local_loss, scaler)
                    backward_total_time += time.time() - backward_start
                    with autocast():
                        if (
                            averagers is not None
                            and args.log_avg_model_training_loss
                            and i % args.log_avg_model_training_loss == 0
                        ):
                            for key, averager in averagers.avgs_dict.items():
                                with torch.no_grad():
                                    out_avg, _, _ = averager.av_model(inputs_ii)
                                    local_avg_losses[key] = (
                                        loss(out_avg.reshape(-1, args.vocab_size), targets_ii.reshape(-1))
                                        * inputs_ii.shape[0]
                                        / inputs.shape[0]
                                    )

                if ii == 0:
                    total_lm_loss = local_lm_loss
                    if args.moe_freq > 0:
                        total_load_balancing_loss = local_load_balancing_loss
                    if (
                        averagers is not None
                        and args.log_avg_model_training_loss
                        and i % args.log_avg_model_training_loss == 0
                    ):
                        for key, averager in averagers.avgs_dict.items():
                            total_loss_avg[key] = local_avg_losses[key]
                else:
                    total_lm_loss += local_lm_loss
                    if args.moe_freq > 0:
                        total_load_balancing_loss += local_load_balancing_loss
                    if (
                        averagers is not None
                        and args.log_avg_model_training_loss
                        and i % args.log_avg_model_training_loss == 0
                    ):
                        for key, averager in averagers.avgs_dict.items():
                            total_loss_avg[key] += local_avg_losses[key]

            forward_time_m.update(forward_total_time)
            backward_time_m.update(backward_total_time)

            total_loss = total_lm_loss
            if args.moe_freq > 0:
                total_loss += total_load_balancing_loss

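        # Optimizer step with optional gradient clipping; gradients are unscaled first when a GradScaler is used.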
        optim_step_start = time.time()
        if scaler is not None:
            if args.grad_clip_norm is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            if args.grad_clip_norm is not None:
                if isinstance(model, FSDP):
                    model.clip_grad_norm_(args.grad_clip_norm, norm_type=2.0)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            optimizer.step()
        optim_step_time_m.update(time.time() - optim_step_start)
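
        # Update any model averagers after the optimizer step.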
        if averagers is not None:
            averagers.step()

        global_loss_tensor = total_loss.detach().clone()
        if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
            for key, value in total_loss_avg.items():
                total_loss_avg[key] = value.detach().clone()

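        # Average the logged losses across ranks; gradient synchronization itself is handled by DDP/FSDP.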
        sync_start = time.time()
        if args.world_size > 1:
            dist.all_reduce(global_loss_tensor, op=ReduceOp.AVG)
            if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
                for key, value in total_loss_avg.items():
                    dist.all_reduce(value, op=ReduceOp.AVG)
            if args.moe_freq > 0:
                dist.all_reduce(total_load_balancing_loss, op=ReduceOp.AVG)
        sync_time_m.update(time.time() - sync_start)

        batch_time_m.update(time.time() - end)
        end = time.time()

        batch_count = i + 1
        step += 1

        if is_master(args):
            batch_size = len(inputs)
            if args.moe_freq > 0:
                losses_m.update(global_loss_tensor.item() - total_load_balancing_loss.item(), batch_size)
                load_balancing_losses_m.update(total_load_balancing_loss.item(), batch_size)
            else:
                losses_m.update(global_loss_tensor.item(), batch_size)
            if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
                for key, value in total_loss_avg.items():
                    losses_avg_m[key].update(value.item(), batch_size)

            if i % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch or step == total_steps - 1:
                num_samples = batch_count * batch_size * args.world_size
                samples_per_epoch = dataloader.num_samples
                percent_complete = 100.0 * batch_count / num_batches_per_epoch

                samples_per_second = inputs.numel() * args.world_size / batch_time_m.val
                samples_per_second_per_gpu = inputs.numel() / batch_time_m.val
                loss_str = f"Loss: {losses_m.avg:.3f}"
                loss_str += f" LB-Loss: {load_balancing_losses_m.avg:.3f}" if args.moe_freq > 0 else ""
                logging.info(
                    f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                    f"{loss_str} "
                    f"Data (t): {data_time_m.avg:.3f} "
                    f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
                    f"LR: {optimizer.param_groups[0]['lr']:5f} "
                )

                log_data = {
                    "loss": losses_m.val,
                    "load_balancing_loss": load_balancing_losses_m.val,
                    "data_time": data_time_m.val,
                    "batch_time": batch_time_m.val,
                    "forward_time": forward_time_m.val,
                    "backward_time": backward_time_m.val,
                    "optim_step_time": optim_step_time_m.val,
                    "sync_time": sync_time_m.val,
                    "samples_per_second": samples_per_second,
                    "samples_per_second_per_gpu": samples_per_second_per_gpu,
                    "lr": optimizer.param_groups[0]["lr"],
                    "tokens": (step + 1) * args.global_batch_size * args.seq_len,
                    "expected_steps_epoch": data["train"].dataloader.num_batches,
                    "seen_steps_epoch": batch_count,
                }

                if averagers is not None and args.log_avg_model_training_loss:
                    if i % args.log_avg_model_training_loss == 0 or batch_count == num_batches_per_epoch:
                        for k in averagers.avgs_dict.keys():
                            log_data[k + "_loss"] = losses_avg_m[k].avg
                if args.log_logit_mean:
                    log_data["logit_mean"] = logit_m.val

                for name, val in log_data.items():
                    name = "train/" + name
                    if tb_writer is not None:
                        tb_writer.add_scalar(name, val, step)
                    if args.wandb:
                        assert wandb is not None, "Please install wandb."
                        wandb.log({name: val, "step": step, "tokens": log_data["tokens"]})

                batch_time_m.reset()
                data_time_m.reset()
                forward_time_m.reset()
                backward_time_m.reset()
                optim_step_time_m.reset()
                sync_time_m.reset()

                if math.isnan(losses_m.val):
                    # Loss has diverged to NaN; abort the epoch so the caller can handle it.
                    logging.error("Loss is NaN. Stopping training.")
                    return False, step

                losses_m.reset()
                if averagers is not None and args.log_avg_model_training_loss:
                    for k in averagers.avgs_dict.keys():
                        losses_avg_m[k].reset()


    if tb_writer is not None:
        tb_writer.flush()
    return True, step