import itertools
import logging
import math
import time
from contextlib import nullcontext

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

try:
    from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss
    from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
    batched_load_balancing_loss = None
    clear_load_balancing_loss = None
    MoEArgs = None

try:
    import wandb
except ImportError:
    wandb = None

from open_lm.data import sample_chunk
from open_lm.distributed import is_master
from open_lm.precision import get_autocast
from open_lm.meters import AverageMeter


def unwrap_model(model):
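    """Return the underlying module when the model is wrapped (e.g., by DistributedDataParallel)."""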
    if hasattr(model, "module"):
        return model.module
    else:
        return model


def backward(total_loss, scaler):
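    """Run the backward pass, scaling the loss first when a GradScaler is used for mixed precision."""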
    if scaler is not None:
        scaler.scale(total_loss).backward()
    else:
        total_loss.backward()


def train_one_epoch(
    model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, averagers=None
):
    """Trains model for one epoch on the provided data.

    Returns:
        success (bool): Whether training completed successfully
        step (int): Global step at the end of the epoch. Note that an "epoch" is not actually one full pass through
            the data, but rather the number of tokens specified by `--train-num-samples`, rounded based on shard size.
            As such, the number of steps in an "epoch" can vary, and the global step must be tracked separately.
    """
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)

    model.train()

    data["train"].set_epoch(epoch)  # set epoch in process safe manner via sampler or shared_epoch
    dataloader = data["train"].dataloader
    num_batches_per_epoch = dataloader.num_batches

    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
    losses_m = AverageMeter()
    load_balancing_losses_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    forward_time_m = AverageMeter()
    backward_time_m = AverageMeter()
    optim_step_time_m = AverageMeter()
    sync_time_m = AverageMeter()
    if averagers is not None and args.log_avg_model_training_loss:
        losses_avg_m = {key: AverageMeter() for key in averagers.avgs_dict.keys()}
        local_avg_losses = {}
        total_loss_avg = {}

    # used only if --log-logit-mean flag is passed
    logit_m = AverageMeter()

    end = time.time()

    data_iterator = iter(dataloader)

    if args.moe_freq > 0:
        # These MoE arguments are needed to compute and log the load-balancing loss.
        moe_args = MoEArgs(
            hidden_size=model.dim,
            ffn_hidden_size=model.dim * 4,
            moe_num_experts=args.moe_num_experts,
            num_layers=model.n_layers // args.moe_freq,
            moe_expert_model_parallelism=True,
            moe_top_k=args.moe_top_k,
            device=torch.cuda.current_device(),
            moe_capacity_factor=args.moe_capacity_factor,
            moe_loss_weight=args.moe_loss_weight,
            fp16=False,
            bf16=False,
        )
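        # megablocks accumulates per-layer load-balancing losses internally; batched_load_balancing_loss
        # gathers them using these arguments, and clear_load_balancing_loss resets the buffer each step.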

    for i in itertools.count():
        if not args.skip_scheduler:
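            # The scheduler is a function of the global step, so the learning rate is set correctly on resume.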
            scheduler(step)

        if step >= total_steps:
            logging.warning(f"step: {step} has reached/exceeded total_steps: {total_steps}. ending training.")
            break

        try:
            batch = next(data_iterator)
            has_data = torch.tensor(1, dtype=torch.long, device=device)
        except StopIteration:
            has_data = torch.tensor(0, dtype=torch.long, device=device)

        if args.world_size > 1:
            dist.all_reduce(has_data, op=ReduceOp.SUM)
        # Stop on all ranks as soon as any rank runs out of data, so the collective ops below don't hang.
        if has_data < args.world_size:
            break

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()
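        # Two paths below: a single forward/backward when accum_freq == 1, otherwise gradient accumulation
        # over accum_freq micro-batches.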
        if args.accum_freq == 1:
            with autocast():
                forward_start = time.time()
                if args.dataset_type == "jsonl":
                    inputs, targets = batch
                    # Batches from the jsonl loader are expected to arrive already padded to equal length.
                    inputs = torch.LongTensor(inputs).to(device)
                    targets = torch.LongTensor(targets).to(device)
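                    # Shift by one position so each input token predicts the following target token.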
                    inputs = inputs[:, :-1]
                    targets = targets[:, 1:]
                    assert inputs.size() == targets.size()
                    if is_master(args) and i == 0:
                        logging.info("Entering custom jsonl training step.")
                        logging.info(f"First batch inputs (first 3 rows):\n{inputs[:3, :]}")
                        logging.info(f"First batch targets (first 3 rows):\n{targets[:3, :]}")
                else:
                    (texts,) = batch
                    texts = torch.LongTensor(texts).to(device)
                    inputs, targets = sample_chunk(texts, args)
                out, _, _ = model(inputs)
                forward_time_m.update(time.time() - forward_start)

                if args.log_logit_mean:
                    logit_m.update(torch.mean(out).item())
                total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1))
                total_loss = total_lm_loss
                if args.moe_freq > 0:
                    total_load_balancing_loss = batched_load_balancing_loss(moe_args)
                    clear_load_balancing_loss()
                    total_loss += total_load_balancing_loss
            backward_start = time.time()
            backward(total_loss, scaler)
            backward_time_m.update(time.time() - backward_start)

            if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
                with autocast():
                    for key, averager in averagers.avgs_dict.items():
                        with torch.no_grad():
                            out_avg, _, _ = averager.av_model(inputs)
                            # save the loss for the average model for logging
                            total_loss_avg[key] = loss(out_avg.reshape(-1, args.vocab_size), targets.reshape(-1))
        else:
            # Split the batch into accum_freq chunks -- e.g., with a per-GPU batch size of 8 and --accum-freq 4,
            # only 2 samples are processed at a time. The per-GPU batch size must be divisible by accum_freq.
            assert args.per_gpu_batch_size % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
            per_batch = args.per_gpu_batch_size // args.accum_freq

            # Mirror the accum_freq == 1 preprocessing so both code paths see identical inputs/targets.
            if args.dataset_type == "jsonl":
                inputs, targets = batch
                inputs = torch.LongTensor(inputs).to(device)[:, :-1]
                targets = torch.LongTensor(targets).to(device)[:, 1:]
            else:
                (texts,) = batch
                texts = torch.LongTensor(texts).to(device)
                inputs, targets = sample_chunk(texts, args)

            forward_total_time = 0
            backward_total_time = 0
            for ii in range(args.accum_freq):
                maybe_no_sync = nullcontext
                # Don't sync gradients until the final batch for FSDP.
                if isinstance(model, FSDP) and ii != args.accum_freq - 1:
                    maybe_no_sync = model.no_sync
                with maybe_no_sync():
                    with autocast():
                        forward_start = time.time()
                        inputs_ii = inputs[ii * per_batch : (ii + 1) * per_batch]
                        if inputs_ii.shape[0] == 0:
                            break
                        targets_ii = targets[ii * per_batch : (ii + 1) * per_batch]
                        out, _, _ = model(inputs_ii)
                        forward_total_time += time.time() - forward_start

                        if args.log_logit_mean:
                            logit_m.update(torch.mean(out).item())

                        local_lm_loss = (
                            loss(out.reshape(-1, args.vocab_size), targets_ii.reshape(-1))
                            * inputs_ii.shape[0]
                            / inputs.shape[0]
                        )
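                        # Each micro-batch loss is weighted by its share of the full per-GPU batch so the
                        # accumulated gradients match a single large-batch step.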
                    local_loss = local_lm_loss
                    if args.moe_freq > 0:
                        local_load_balancing_loss = batched_load_balancing_loss(moe_args)
                        clear_load_balancing_loss()
                        local_loss += local_load_balancing_loss

                    backward_start = time.time()
                    backward(local_loss, scaler)
                    backward_total_time += time.time() - backward_start
                    with autocast():
                        if (
                            averagers is not None
                            and args.log_avg_model_training_loss
                            and i % args.log_avg_model_training_loss == 0
                        ):
                            for key, averager in averagers.avgs_dict.items():
                                with torch.no_grad():
                                    out_avg, _, _ = averager.av_model(inputs_ii)
                                    local_avg_losses[key] = (
                                        loss(out_avg.reshape(-1, args.vocab_size), targets_ii.reshape(-1))
                                        * inputs_ii.shape[0]
                                        / inputs.shape[0]
                                    )
                if ii == 0:
                    total_lm_loss = local_lm_loss
                    if args.moe_freq > 0:
                        total_load_balancing_loss = local_load_balancing_loss
                    if (
                        averagers is not None
                        and args.log_avg_model_training_loss
                        and i % args.log_avg_model_training_loss == 0
                    ):
                        for key, averager in averagers.avgs_dict.items():
                            total_loss_avg[key] = local_avg_losses[key]
                else:
                    total_lm_loss += local_lm_loss
                    if args.moe_freq > 0:
                        total_load_balancing_loss += local_load_balancing_loss
                    if (
                        averagers is not None
                        and args.log_avg_model_training_loss
                        and i % args.log_avg_model_training_loss == 0
                    ):
                        for key, averager in averagers.avgs_dict.items():
                            total_loss_avg[key] += local_avg_losses[key]

            forward_time_m.update(forward_total_time)
            backward_time_m.update(backward_total_time)

            total_loss = total_lm_loss
            if args.moe_freq > 0:
                total_loss += total_load_balancing_loss

        optim_step_start = time.time()
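        # With a GradScaler, gradients are unscaled before clipping so the threshold applies to true norms.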
        if scaler is not None:
            if args.grad_clip_norm is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            if args.grad_clip_norm is not None:
                if isinstance(model, FSDP):
                    model.clip_grad_norm_(args.grad_clip_norm, norm_type=2.0)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            optimizer.step()
        optim_step_time_m.update(time.time() - optim_step_start)

        if averagers is not None:
            averagers.step()

        global_loss_tensor = total_loss.detach().clone()
        if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
            # same for the average model loss
            for key, value in total_loss_avg.items():
                total_loss_avg[key] = value.detach().clone()

        sync_start = time.time()
        if args.world_size > 1:
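            # Average the detached losses across ranks for logging only; gradients are unaffected.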
            dist.all_reduce(global_loss_tensor, op=ReduceOp.AVG)
            if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
                for key, value in total_loss_avg.items():
                    dist.all_reduce(value, op=ReduceOp.AVG)
            if args.moe_freq > 0:
                dist.all_reduce(total_load_balancing_loss, op=ReduceOp.AVG)
        sync_time_m.update(time.time() - sync_start)

        batch_time_m.update(time.time() - end)
        end = time.time()

        batch_count = i + 1
        step += 1
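        # Logging below runs on the master rank only, using the cross-rank averaged loss.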
        if is_master(args):
            batch_size = len(inputs)
            if args.moe_freq > 0:
                losses_m.update(global_loss_tensor.item() - total_load_balancing_loss.item(), batch_size)
                load_balancing_losses_m.update(total_load_balancing_loss.item(), batch_size)
            else:
                losses_m.update(global_loss_tensor.item(), batch_size)
            if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0:
                for key, value in total_loss_avg.items():
                    losses_avg_m[key].update(value.item(), batch_size)
            if i % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch or step == total_steps - 1:
                num_samples = batch_count * batch_size * args.world_size
                samples_per_epoch = dataloader.num_samples
                percent_complete = 100.0 * batch_count / num_batches_per_epoch

                samples_per_second = inputs.numel() * args.world_size / batch_time_m.val
                samples_per_second_per_gpu = inputs.numel() / batch_time_m.val
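                # numel() counts tokens, so these throughput figures are tokens per second rather than sequences.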
                loss_str = f"Loss: {losses_m.avg:.3f}"
                loss_str += f" LB-Loss: {load_balancing_losses_m.avg:.3f}" if args.moe_freq > 0 else ""
                logging.info(
                    f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                    f"{loss_str} "
                    f"Data (t): {data_time_m.avg:.3f} "
                    f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
                    f"LR: {optimizer.param_groups[0]['lr']:5f} "
                )

                # Save train loss / etc. Use non-averaged meter values since loggers have their own smoothing.
                log_data = {
                    "loss": losses_m.val,
                    "load_balancing_loss": load_balancing_losses_m.val,
                    "data_time": data_time_m.val,
                    "batch_time": batch_time_m.val,
                    "forward_time": forward_time_m.val,
                    "backward_time": backward_time_m.val,
                    "optim_step_time": optim_step_time_m.val,
                    "sync_time": sync_time_m.val,
                    "samples_per_second": samples_per_second,
                    "samples_per_second_per_gpu": samples_per_second_per_gpu,
                    "lr": optimizer.param_groups[0]["lr"],
                    "tokens": (step + 1) * args.global_batch_size * args.seq_len,
                    "expected_steps_epoch": data["train"].dataloader.num_batches,
                    "seen_steps_epoch": batch_count,
                }
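                # "tokens" approximates tokens seen so far, assuming each step consumed a full global batch of seq_len tokens.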

                if (
                    averagers is not None
                    and args.log_avg_model_training_loss
                    and (i % args.log_avg_model_training_loss == 0 or batch_count == num_batches_per_epoch)
                ):
                    for k in averagers.avgs_dict.keys():
                        log_data[k + "_loss"] = losses_avg_m[k].avg
                if args.log_logit_mean:
                    log_data["logit_mean"] = logit_m.val

                for name, val in log_data.items():
                    name = "train/" + name
                    if tb_writer is not None:
                        tb_writer.add_scalar(name, val, step)
                    if args.wandb:
                        assert wandb is not None, "Please install wandb."
                        wandb.log({name: val, "step": step, "tokens": log_data["tokens"]})

                # resetting batch / data time meters per log window
                batch_time_m.reset()
                data_time_m.reset()
                forward_time_m.reset()
                backward_time_m.reset()
                optim_step_time_m.reset()
                sync_time_m.reset()

                if math.isnan(losses_m.val):
                    # Loss occasionally goes to NaN (we see this sometimes with bad nodes). In that case, bail
                    # out to free resources and avoid saving checkpoints / optimization states that could lead
                    # to skipped training on restarts.
                    return False, step

                # reset all average meters
                losses_m.reset()
                if averagers is not None and args.log_avg_model_training_loss:
                    for k in averagers.avgs_dict.keys():
                        losses_avg_m[k].reset()

    # end for
    if tb_writer is not None:
        tb_writer.flush()
    return True, step