| | """ |
| | Train model. From root directory of the project, run as: |
| | |
| | python -m scripts.base_train |
| | |
| | or distributed as: |
| | |
| | torchrun --nproc_per_node=8 -m scripts.base_train |
| | |
| | If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example: |
| | python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 |
| | """ |
| |
|
| | import gc |
| | import os |
| | os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" |
| | import argparse |
| | import time |
| | from contextlib import nullcontext, contextmanager |
| |
|
| | import wandb |
| | import torch |
| |
|
| | from nanochat.gpt import GPT, GPTConfig |
| | from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit |
| | from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops |
| | from nanochat.tokenizer import get_tokenizer, get_token_bytes |
| | from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint |
| | from nanochat.loss_eval import evaluate_bpb |
| | from nanochat.engine import Engine |
| | from nanochat.flash_attention import HAS_FA3 |
| | from scripts.base_eval import evaluate_core |
| | print_banner() |
| |
|
| | |
| | |
| | parser = argparse.ArgumentParser(description="Pretrain base model") |
| | |
| | parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") |
| | |
| | parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") |
| | |
| | parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") |
| | parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") |
| | |
| | parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") |
| | parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") |
| | parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") |
| | parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") |
| | parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") |
| | |
| | parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") |
| | parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") |
| | parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") |
| | |
| | parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") |
| | parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") |
| | parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") |
| | parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)") |
| | parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)") |
| | parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") |
| | parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") |
| | parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding") |
| | parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding") |
| | parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup") |
| | parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown") |
| | parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR") |
| | parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") |
| | |
| | parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") |
| | parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on") |
| | parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") |
| | parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") |
| | parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") |
| | parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") |
| | |
| | parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") |
| | args = parser.parse_args() |
| | user_config = vars(args).copy() |
| | |
| |
|
| | |
| | device_type = autodetect_device_type() if args.device_type == "" else args.device_type |
| | ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) |
| | master_process = ddp_rank == 0 |
| | autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() |
| | synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None |
| | get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 |
| | if device_type == "cuda": |
| | gpu_device_name = torch.cuda.get_device_name(0) |
| | gpu_peak_flops = get_peak_flops(gpu_device_name) |
| | print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") |
| | else: |
| | gpu_peak_flops = float('inf') |
| |
|
| | |
| | use_dummy_wandb = args.run == "dummy" or not master_process |
| | wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) |
| |
|
| | |
| | if HAS_FA3: |
| | print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") |
| | else: |
| | print0("!" * 80) |
| | print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") |
| | print0("WARNING: Training will be less efficient without FA3") |
| | if args.window_pattern != "L": |
| | print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.") |
| | print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.") |
| | print0("!" * 80) |
| |
|
| | |
| | tokenizer = get_tokenizer() |
| | token_bytes = get_token_bytes(device=device) |
| | vocab_size = tokenizer.get_vocab_size() |
| | print0(f"Vocab size: {vocab_size:,}") |
| |
|
| | |
| | |
| | |
| | |
| | num_layers = args.depth |
| | base_dim = args.depth * args.aspect_ratio |
| | model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim |
| | num_heads = model_dim // args.head_dim |
| | num_kv_heads = num_heads |
| | head_dim = model_dim // num_heads |
| | print0(f"num_layers: {num_layers}") |
| | print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})") |
| | print0(f"num_heads: {num_heads}") |
| | print0(f"head_dim: {head_dim}") |
| | print0(f"num_kv_heads: {num_kv_heads}") |
| |
|
| | |
| | |
| | tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len |
| | world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size |
| | assert args.total_batch_size % world_tokens_per_fwdbwd == 0 |
| | grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd |
| | print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") |
| | print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") |
| | print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") |
| |
|
| | |
| | batch_lr_scale = 1.0 |
| | reference_batch_size = 2**19 |
| | batch_ratio = args.total_batch_size / reference_batch_size |
| | if batch_ratio != 1.0: |
| | |
| | |
| | |
| | batch_lr_scale = batch_ratio ** 0.5 |
| | print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})") |
| |
|
| | |
| | weight_decay_scaled = args.weight_decay * (12 / args.depth)**2 |
| | if args.depth != 12: |
| | print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") |
| |
|
| | |
| | |
| |
|
| | |
| | model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern) |
| | with torch.device("meta"): |
| | |
| | model_config = GPTConfig(**model_config_kwargs) |
| | model = GPT(model_config) |
| | model.to_empty(device=device) |
| | model.init_weights() |
| |
|
| | |
| | base_dir = get_base_dir() |
| | output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" |
| | checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname) |
| | resuming = args.resume_from_step != -1 |
| | if resuming: |
| | print0(f"Resuming optimization from step {args.resume_from_step}") |
| | model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank) |
| | model.load_state_dict(model_data, strict=True, assign=True) |
| | del model_data |
| |
|
| | |
| | |
| |
|
| | |
| | param_counts = model.num_scaling_params() |
| | print0(f"Parameter counts:") |
| | for key, value in param_counts.items(): |
| | print0(f"{key:24s}: {value:,}") |
| | num_params = param_counts['total'] |
| | num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] |
| | num_flops_per_token = model.estimate_flops() |
| | print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") |
| |
|
| | |
| | assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0 |
| | if args.num_iterations > 0: |
| | num_iterations = args.num_iterations |
| | print0(f"Using user-provided number of iterations: {num_iterations:,}") |
| | elif args.target_flops > 0: |
| | |
| | num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size)) |
| | print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") |
| | elif args.target_param_data_ratio > 0: |
| | |
| | target_tokens = int(args.target_param_data_ratio * num_scaling_params) |
| | num_iterations = target_tokens // args.total_batch_size |
| | print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") |
| | else: |
| | raise ValueError("No training horizon specified") |
| | total_tokens = args.total_batch_size * num_iterations |
| | print0(f"Total number of training tokens: {total_tokens:,}") |
| | print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") |
| | print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") |
| |
|
| | |
| | |
| |
|
| | |
| | if args.fp8: |
| | if device_type != "cuda": |
| | print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag") |
| | else: |
| | from torchao.float8 import Float8LinearConfig, convert_to_float8_training |
| | import torch.nn as nn |
| |
|
| | |
| | def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: |
| | if not isinstance(mod, nn.Linear): |
| | return False |
| | |
| | if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: |
| | return False |
| | return True |
| |
|
| | fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) |
| | convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) |
| | num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) |
| | num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers |
| | print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)") |
| |
|
| | |
| | @contextmanager |
| | def disable_fp8(model): |
| | """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation. |
| | |
| | CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead, |
| | we swap out Float8Linear modules entirely and restore them after. |
| | """ |
| | import torch.nn as nn |
| |
|
| | |
| | fp8_locations = [] |
| | for name, module in model.named_modules(): |
| | if 'Float8' in type(module).__name__: |
| | if '.' in name: |
| | parent_name, attr_name = name.rsplit('.', 1) |
| | parent = model.get_submodule(parent_name) |
| | else: |
| | parent = model |
| | attr_name = name |
| | fp8_locations.append((parent, attr_name, module)) |
| |
|
| | if not fp8_locations: |
| | yield |
| | return |
| |
|
| | |
| | for parent, attr_name, fp8_module in fp8_locations: |
| | linear = nn.Linear( |
| | fp8_module.in_features, |
| | fp8_module.out_features, |
| | bias=fp8_module.bias is not None, |
| | device=fp8_module.weight.device, |
| | dtype=fp8_module.weight.dtype, |
| | ) |
| | linear.weight = fp8_module.weight |
| | if fp8_module.bias is not None: |
| | linear.bias = fp8_module.bias |
| | setattr(parent, attr_name, linear) |
| |
|
| | try: |
| | yield |
| | finally: |
| | |
| | for parent, attr_name, fp8_module in fp8_locations: |
| | setattr(parent, attr_name, fp8_module) |
| |
|
| | |
| | |
| |
|
| | orig_model = model |
| | model = torch.compile(model, dynamic=False) |
| |
|
| | |
| | |
| | adam_betas = (args.adam_beta1, args.adam_beta2) |
| | optimizer = model.setup_optimizer( |
| | unembedding_lr=args.unembedding_lr * batch_lr_scale, |
| | embedding_lr=args.embedding_lr * batch_lr_scale, |
| | matrix_lr=args.matrix_lr * batch_lr_scale, |
| | weight_decay=weight_decay_scaled, |
| | adam_betas=adam_betas, |
| | scalar_lr=args.scalar_lr * batch_lr_scale, |
| | ) |
| |
|
| | if resuming: |
| | optimizer.load_state_dict(optimizer_data) |
| | del optimizer_data |
| |
|
| | |
| | |
| | dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] |
| | train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) |
| | build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) |
| | x, y, dataloader_state_dict = next(train_loader) |
| |
|
| | |
| | |
| |
|
| | |
| | def get_lr_multiplier(it): |
| | warmup_iters = round(args.warmup_ratio * num_iterations) |
| | warmdown_iters = round(args.warmdown_ratio * num_iterations) |
| | if it < warmup_iters: |
| | return (it + 1) / warmup_iters |
| | elif it <= num_iterations - warmdown_iters: |
| | return 1.0 |
| | else: |
| | progress = (num_iterations - it) / warmdown_iters |
| | return progress * 1.0 + (1 - progress) * args.final_lr_frac |
| |
|
| | |
| | def get_muon_momentum(it): |
| | frac = min(it / 300, 1) |
| | momentum = (1 - frac) * 0.85 + frac * 0.95 |
| | return momentum |
| |
|
| | |
| | def get_weight_decay(it): |
| | return weight_decay_scaled * (1 - it / num_iterations) |
| |
|
| | |
| | |
| |
|
| | if not resuming: |
| | step = 0 |
| | val_bpb = None |
| | min_val_bpb = float("inf") |
| | smooth_train_loss = 0 |
| | total_training_time = 0 |
| | else: |
| | step = meta_data["step"] |
| | loop_state = meta_data["loop_state"] |
| | val_bpb = meta_data["val_bpb"] |
| | min_val_bpb = loop_state["min_val_bpb"] |
| | smooth_train_loss = loop_state["smooth_train_loss"] |
| | total_training_time = loop_state["total_training_time"] |
| |
|
| | |
| | |
| | while True: |
| | last_step = step == num_iterations |
| | flops_so_far = num_flops_per_token * args.total_batch_size * step |
| |
|
| | |
| | if args.eval_every > 0 and (last_step or step % args.eval_every == 0): |
| | model.eval() |
| | val_loader = build_val_loader() |
| | eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) |
| | with disable_fp8(model), autocast_ctx: |
| | val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) |
| | print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}") |
| | if val_bpb < min_val_bpb: |
| | min_val_bpb = val_bpb |
| | wandb_run.log({ |
| | "step": step, |
| | "total_training_flops": flops_so_far, |
| | "total_training_time": total_training_time, |
| | "val/bpb": val_bpb, |
| | }) |
| | model.train() |
| |
|
| | |
| | |
| | |
| | results = {} |
| | if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)): |
| | model.eval() |
| | with disable_fp8(orig_model), autocast_ctx: |
| | results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) |
| | print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") |
| | wandb_run.log({ |
| | "step": step, |
| | "total_training_flops": flops_so_far, |
| | "core_metric": results["core_metric"], |
| | "centered_results": results["centered_results"], |
| | }) |
| | model.train() |
| |
|
| | |
| | |
| | if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)): |
| | model.eval() |
| | prompts = [ |
| | "The capital of France is", |
| | "The chemical symbol of gold is", |
| | "If yesterday was Friday, then tomorrow will be", |
| | "The opposite of hot is", |
| | "The planets of the solar system are:", |
| | "My favorite color is", |
| | "If 5*x + 3 = 13, then x is", |
| | ] |
| | engine = Engine(orig_model, tokenizer) |
| | for prompt in prompts: |
| | tokens = tokenizer(prompt, prepend="<|bos|>") |
| | with disable_fp8(orig_model), autocast_ctx: |
| | sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) |
| | print0(tokenizer.decode(sample[0])) |
| | model.train() |
| |
|
| | |
| | if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0): |
| | save_checkpoint( |
| | checkpoint_dir, |
| | step, |
| | orig_model.state_dict(), |
| | optimizer.state_dict(), |
| | { |
| | "step": step, |
| | "val_bpb": val_bpb, |
| | "model_config": model_config_kwargs, |
| | "user_config": user_config, |
| | "device_batch_size": args.device_batch_size, |
| | "max_seq_len": args.max_seq_len, |
| | "dataloader_state_dict": dataloader_state_dict, |
| | "loop_state": { |
| | "min_val_bpb": min_val_bpb, |
| | "smooth_train_loss": smooth_train_loss, |
| | "total_training_time": total_training_time, |
| | }, |
| | }, |
| | rank=ddp_rank, |
| | ) |
| |
|
| | |
| | if last_step: |
| | break |
| |
|
| | |
| | |
| | |
| | synchronize() |
| | t0 = time.time() |
| | for micro_step in range(grad_accum_steps): |
| | with autocast_ctx: |
| | loss = model(x, y) |
| | train_loss = loss.detach() |
| | loss = loss / grad_accum_steps |
| | loss.backward() |
| | x, y, dataloader_state_dict = next(train_loader) |
| | |
| | lrm = get_lr_multiplier(step) |
| | muon_momentum = get_muon_momentum(step) |
| | muon_weight_decay = get_weight_decay(step) |
| | for group in optimizer.param_groups: |
| | group["lr"] = group["initial_lr"] * lrm |
| | if group['kind'] == 'muon': |
| | group["momentum"] = muon_momentum |
| | group["weight_decay"] = muon_weight_decay |
| | optimizer.step() |
| | model.zero_grad(set_to_none=True) |
| | train_loss_f = train_loss.item() |
| | synchronize() |
| | t1 = time.time() |
| | dt = t1 - t0 |
| | |
| |
|
| | |
| | ema_beta = 0.9 |
| | smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f |
| | debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) |
| | pct_done = 100 * step / num_iterations |
| | tok_per_sec = int(args.total_batch_size / dt) |
| | flops_per_sec = num_flops_per_token * args.total_batch_size / dt |
| | mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) |
| | if step > 10: |
| | total_training_time += dt |
| | |
| | steps_done = step - 10 |
| | if steps_done > 0: |
| | avg_time_per_step = total_training_time / steps_done |
| | remaining_steps = num_iterations - step |
| | eta_seconds = remaining_steps * avg_time_per_step |
| | eta_str = f" | eta: {eta_seconds/60:.1f}m" |
| | else: |
| | eta_str = "" |
| | epoch = dataloader_state_dict["epoch"] |
| | print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") |
| | if step % 100 == 0: |
| | log_data = { |
| | "step": step, |
| | "total_training_flops": flops_so_far, |
| | "total_training_time": total_training_time, |
| | "train/loss": debiased_smooth_loss, |
| | "train/lrm": lrm, |
| | "train/dt": dt, |
| | "train/tok_per_sec": tok_per_sec, |
| | "train/mfu": mfu, |
| | "train/epoch": epoch, |
| | } |
| | wandb_run.log(log_data) |
| |
|
| | |
| | first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step) |
| | step += 1 |
| |
|
| | |
| | |
| | |
| | if first_step_of_run: |
| | gc.collect() |
| | gc.freeze() |
| | gc.disable() |
| | elif step % 5000 == 0: |
| | gc.collect() |
| |
|
| | |
| | print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") |
| | print0(f"Total training time: {total_training_time/60:.2f}m") |
| | if val_bpb is not None: |
| | print0(f"Minimum validation bpb: {min_val_bpb:.6f}") |
| |
|
| | |
| | from nanochat.report import get_report |
| | get_report().log(section="Base model training", data=[ |
| | user_config, |
| | { |
| | "Number of parameters": num_params, |
| | "Number of FLOPs per token": f"{num_flops_per_token:e}", |
| | "Calculated number of iterations": num_iterations, |
| | "Number of training tokens": total_tokens, |
| | "Tokens : Scaling params ratio": args.total_batch_size * num_iterations / num_scaling_params, |
| | "DDP world size": ddp_world_size, |
| | "warmup_ratio": args.warmup_ratio, |
| | "warmdown_ratio": args.warmdown_ratio, |
| | "final_lr_frac": args.final_lr_frac, |
| | }, |
| | { |
| | "Minimum validation bpb": min_val_bpb if val_bpb is not None else None, |
| | "Final validation bpb": val_bpb, |
| | "CORE metric estimate": results.get("core_metric", None), |
| | "MFU %": f"{mfu:.2f}%", |
| | "Total training flops": f"{flops_so_far:e}", |
| | "Total training time": f"{total_training_time/60:.2f}m", |
| | "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB", |
| | } |
| | ]) |
| |
|
| | |
| | wandb_run.finish() |
| | compute_cleanup() |
| |
|