"""
Main training script for LLM training on TPU v4-32.

Optimized for 128K token context length and 30-day training.
"""

import os
import sys
import time
import json
import argparse
import logging
import threading
import queue
from typing import Dict, Any, Optional, List, Tuple
import jax
import jax.numpy as jnp
import flax
import tensorflow as tf
import numpy as np
import sentencepiece as spm
from functools import partial


logger = logging.getLogger(__name__)


try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    logger.warning("Weights & Biases not available. WandB logging will be disabled.")
    WANDB_AVAILABLE = False


from model.llm import LLM, LLMConfig
from data.tokenizer import SentencePieceTokenizer
from data.dataset import TextDataset, load_jsonl_dataset, StreamingDataset
from data.dataloader import TPUDataLoader
from training.trainer import Trainer, TrainingState, TrainingConfig as TrainerConfig
from training.optimizer import create_adamw_optimizer, create_lion_optimizer
from training.scheduler import create_linear_warmup_cosine_decay_schedule
from parallelism.data_parallel import DataParallel
from parallelism.tensor_parallel import TensorParallel
# Note: create_config is defined locally in this script, so it is not imported from config.
from config import Config, TrainingConfig, get_model_config
from utils.checkpoint import save_checkpoint, load_checkpoint
from utils.logging import setup_logger, log_metrics, create_summary_writer, log_metrics_to_tensorboard


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Train LLM on TPU v4-32")

    # Model
    parser.add_argument("--model_size", type=str, default="7b", choices=["7b", "13b", "70b", "175b", "600b"],
                        help="Model size")

    # Optimization
    parser.add_argument("--learning_rate", type=float, default=3e-4,
                        help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of steps to accumulate gradients")
    parser.add_argument("--max_steps", type=int, default=100000,
                        help="Maximum number of training steps")
    parser.add_argument("--warmup_steps", type=int, default=1000,
                        help="Number of warmup steps")

    # Data
    # Boolean feature flags below default to on; argparse.BooleanOptionalAction also
    # generates a --no-<flag> form so they can actually be disabled from the CLI
    # (action="store_true" with default=True could never be turned off).
    parser.add_argument("--train_file", type=str, required=True,
                        help="Path to training file or HuggingFace dataset name")
    parser.add_argument("--eval_file", type=str, default="",
                        help="Path to evaluation file or HuggingFace dataset name")
    parser.add_argument("--tokenizer_file", type=str, required=True,
                        help="Path to tokenizer file")
    parser.add_argument("--max_seq_length", type=int, default=131072,
                        help="Maximum sequence length (default: 128K tokens)")
    parser.add_argument("--use_streaming", action=argparse.BooleanOptionalAction, default=True,
                        help="Use streaming dataset for efficient training")
    parser.add_argument("--streaming_buffer_size", type=int, default=10000,
                        help="Buffer size for streaming dataset")
    parser.add_argument("--text_column", type=str, default="text",
                        help="Name of text column in dataset")
    parser.add_argument("--preprocessing_num_workers", type=int, default=16,
                        help="Number of workers for dataset preprocessing")

    # Parallelism
    parser.add_argument("--parallelism_type", type=str, default="data", choices=["data", "tensor"],
                        help="Type of parallelism")
    parser.add_argument("--tensor_parallel_size", type=int, default=8,
                        help="Number of tensor parallel devices")

    # Memory and attention optimizations
    parser.add_argument("--use_flash_attention", action=argparse.BooleanOptionalAction, default=True,
                        help="Use flash attention for efficiency")
    parser.add_argument("--use_gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True,
                        help="Use gradient checkpointing to save memory")

    # Long-context (RoPE) options
    parser.add_argument("--use_rope_scaling", action=argparse.BooleanOptionalAction, default=True,
                        help="Use RoPE scaling for longer contexts")
    parser.add_argument("--rope_scaling_factor", type=float, default=0.5,
                        help="Scaling factor for RoPE frequencies")

    # Reasoning layers
    parser.add_argument("--use_reasoning_layer", action=argparse.BooleanOptionalAction, default=True,
                        help="Use additional reasoning layers")
    parser.add_argument("--num_reasoning_layers", type=int, default=None,
                        help="Number of additional reasoning layers (overrides model config)")

    # Output and checkpointing
    parser.add_argument("--output_dir", type=str, default="output",
                        help="Output directory")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Number of steps between logging")
    parser.add_argument("--save_steps", type=int, default=1000,
                        help="Number of steps between checkpoints")
    parser.add_argument("--eval_steps", type=int, default=1000,
                        help="Number of steps between evaluations")

    # Logging and monitoring
    parser.add_argument("--use_wandb", action=argparse.BooleanOptionalAction, default=True,
                        help="Use Weights & Biases for logging")
    parser.add_argument("--wandb_project", type=str, default="llm-training",
                        help="Weights & Biases project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="Weights & Biases entity name")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Weights & Biases run name")
    parser.add_argument("--log_memory_usage", action=argparse.BooleanOptionalAction, default=True,
                        help="Log memory usage during training")
    parser.add_argument("--profile_steps", type=int, default=100,
                        help="Number of steps between profiling")

    # Misc
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--resume_from_checkpoint", type=str, default="",
                        help="Path to checkpoint to resume from")

    return parser.parse_args()


def create_config(args):
    """Create training configuration."""
    model_config = get_model_config(args.model_size)

    # Override the number of reasoning layers if specified on the command line
    if args.num_reasoning_layers is not None:
        model_config.num_reasoning_layers = args.num_reasoning_layers

    # Apply architecture options from the command line
    model_config.use_flash_attention = args.use_flash_attention
    model_config.use_gradient_checkpointing = args.use_gradient_checkpointing
    model_config.use_rope_scaling = args.use_rope_scaling
    model_config.rope_scaling_factor = args.rope_scaling_factor
    model_config.use_reasoning_layer = args.use_reasoning_layer

    config = TrainingConfig(
        output_dir=args.output_dir,
        model_config=model_config,

        # Optimization
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,

        # Data
        train_file=args.train_file,
        eval_file=args.eval_file,
        tokenizer_file=args.tokenizer_file,
        max_seq_length=args.max_seq_length,
        # Data pipeline options (consumed by load_dataset below)
        use_streaming=args.use_streaming,
        streaming_buffer_size=args.streaming_buffer_size,
        text_column=args.text_column,
        preprocessing_num_workers=args.preprocessing_num_workers,

        # Parallelism
        parallelism_type=args.parallelism_type,
        tensor_parallel_size=args.tensor_parallel_size,

        # Memory and attention optimizations
        use_flash_attention=args.use_flash_attention,
        use_gradient_checkpointing=args.use_gradient_checkpointing,

        # Long-context (RoPE) options
        use_rope_scaling=args.use_rope_scaling,
        rope_scaling_factor=args.rope_scaling_factor,

        # Reasoning layers
        use_reasoning_layer=args.use_reasoning_layer,
        num_reasoning_layers=args.num_reasoning_layers if args.num_reasoning_layers is not None else model_config.num_reasoning_layers,
        reasoning_intermediate_size=model_config.reasoning_intermediate_size,

        # Logging and checkpointing
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,

        seed=args.seed
    )

    return config


def setup_parallelism(config):
    """Set up parallelism."""
    if config.parallelism_type == "data":
        return DataParallel()
    elif config.parallelism_type == "tensor":
        return TensorParallel(num_tp=config.tensor_parallel_size)
    else:
        raise ValueError(f"Parallelism type {config.parallelism_type} not supported")


def create_model(config):
    """Create model."""
    return LLM(config.model_config)


def create_optimizer(config, num_train_steps):
    """Create optimizer."""
    lr_schedule = create_linear_warmup_cosine_decay_schedule(
        learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps,
        decay_steps=num_train_steps - config.warmup_steps,
        final_learning_rate_factor=0.1
    )

    if config.optimizer == "adamw":
        return create_adamw_optimizer(
            learning_rate=lr_schedule,
            weight_decay=config.weight_decay,
            b1=config.adam_beta1,
            b2=config.adam_beta2,
            eps=config.adam_epsilon
        )
    elif config.optimizer == "lion":
        return create_lion_optimizer(
            learning_rate=lr_schedule,
            weight_decay=config.weight_decay,
            b1=config.adam_beta1,
            b2=config.adam_beta2
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer} not supported")


def create_train_state(config, model, optimizer, rng):
    """Create training state."""
    # Initialize parameters with a minimal dummy token batch
    dummy_input = jnp.ones((1, 1), dtype=jnp.int32)

    params_rng, dropout_rng = jax.random.split(rng)
    params = model.init(params_rng, dummy_input)

    return TrainingState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        dropout_rng=dropout_rng,
        loss_scale=1.0
    )


def load_tokenizer(config):
    """Load tokenizer."""
    return SentencePieceTokenizer(config.tokenizer_file)


def load_dataset(config, tokenizer):
    """Load dataset with streaming support for efficient training."""
    if config.use_streaming:
        logger.info(f"Loading streaming dataset from {config.train_file}")
        train_dataset = StreamingDataset(
            tokenizer=tokenizer,
            dataset_path=config.train_file,
            max_seq_length=config.max_seq_length,
            streaming=True,
            buffer_size=config.streaming_buffer_size,
            seed=config.seed,
            text_column=config.text_column,
            preprocessing_num_workers=config.preprocessing_num_workers
        )
        logger.info("Streaming dataset loaded successfully")
    else:
        logger.info(f"Loading standard dataset from {config.train_file}")
        train_dataset = load_jsonl_dataset(
            file_path=config.train_file,
            tokenizer=tokenizer,
            max_length=config.max_seq_length
        )
        logger.info(f"Dataset loaded with {len(train_dataset)} examples")

    eval_dataset = None
    if config.eval_file:
        if config.use_streaming:
            logger.info(f"Loading streaming evaluation dataset from {config.eval_file}")
            eval_dataset = StreamingDataset(
                tokenizer=tokenizer,
                dataset_path=config.eval_file,
                max_seq_length=config.max_seq_length,
                streaming=False,
                buffer_size=config.streaming_buffer_size,
                seed=config.seed,
                text_column=config.text_column,
                preprocessing_num_workers=config.preprocessing_num_workers
            )
            logger.info("Streaming evaluation dataset loaded successfully")
        else:
            logger.info(f"Loading standard evaluation dataset from {config.eval_file}")
            eval_dataset = load_jsonl_dataset(
                file_path=config.eval_file,
                tokenizer=tokenizer,
                max_length=config.max_seq_length
            )
            logger.info(f"Evaluation dataset loaded with {len(eval_dataset)} examples")

    return train_dataset, eval_dataset


def create_data_loaders(config, train_dataset, eval_dataset, tokenizer):
    """Create data loaders."""
    train_loader = TPUDataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True,
        pad_token_id=tokenizer.pad_token_id
    )

    eval_loader = None
    if eval_dataset is not None:
        eval_loader = TPUDataLoader(
            dataset=eval_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            drop_last=False,
            pad_token_id=tokenizer.pad_token_id
        )

    return train_loader, eval_loader


def main():
    """Main function optimized for TPU v4-32."""
    args = parse_args()

    # Report the TPU topology
    print("TPU Configuration:")
    print(f"Number of TPU devices: {jax.device_count()}")
    print(f"TPU devices: {jax.devices()}")
    print(f"JAX process index: {jax.process_index()}")
    print(f"JAX process count: {jax.process_count()}")
    print(f"JAX local devices: {jax.local_devices()}")
    print(f"JAX local device count: {jax.local_device_count()}")

    # Build the training configuration
    config = create_config(args)

    os.makedirs(config.output_dir, exist_ok=True)

    # Set up logging
    logger = setup_logger(
        name="tpu_train",
        log_file=os.path.join(config.output_dir, "train.log")
    )

    logger.info(f"Configuration: {config}")

    # Initialize Weights & Biases if requested and available
    if args.use_wandb and WANDB_AVAILABLE:
        logger.info("Initializing Weights & Biases")
        wandb_run_name = args.wandb_run_name or f"{args.model_size}-{time.strftime('%Y%m%d-%H%M%S')}"
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=wandb_run_name,
            config={
                "model_size": args.model_size,
                "learning_rate": args.learning_rate,
                "batch_size": args.batch_size,
                "gradient_accumulation_steps": args.gradient_accumulation_steps,
                "max_steps": args.max_steps,
                "warmup_steps": args.warmup_steps,
                "max_seq_length": args.max_seq_length,
                "parallelism_type": args.parallelism_type,
                "tensor_parallel_size": args.tensor_parallel_size,
                "use_flash_attention": args.use_flash_attention,
                "use_gradient_checkpointing": args.use_gradient_checkpointing,
                "use_rope_scaling": args.use_rope_scaling,
                "rope_scaling_factor": args.rope_scaling_factor,
                "use_reasoning_layer": args.use_reasoning_layer,
                "num_reasoning_layers": args.num_reasoning_layers,
                "use_streaming": args.use_streaming,
                "streaming_buffer_size": args.streaming_buffer_size,
                "text_column": args.text_column,
                "preprocessing_num_workers": args.preprocessing_num_workers,
                "seed": args.seed,
            }
        )
        logger.info(f"Weights & Biases initialized with run name: {wandb_run_name}")
    elif args.use_wandb and not WANDB_AVAILABLE:
        logger.warning("Weights & Biases not available. Install the wandb package to enable logging.")
    else:
        logger.info("Weights & Biases logging disabled.")

    # Log the training setup
    logger.info(f"Training on TPU v4-32 with {jax.device_count()} devices")
    logger.info(f"Model size: {args.model_size} ({config.model_config.hidden_size} hidden size, "
                f"{config.model_config.num_hidden_layers} layers)")
    logger.info(f"Max sequence length: {args.max_seq_length} tokens")
    logger.info(f"Batch size: {args.batch_size} per device, {args.batch_size * jax.device_count()} global")
    logger.info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
    logger.info(f"Effective batch size: {args.batch_size * jax.device_count() * args.gradient_accumulation_steps}")
    logger.info(f"Learning rate: {args.learning_rate}")
    logger.info(f"Warmup steps: {args.warmup_steps}")
    logger.info(f"Max steps: {args.max_steps}")
    logger.info(f"Parallelism type: {args.parallelism_type}")
    logger.info(f"Tensor parallel size: {args.tensor_parallel_size}")
    logger.info(f"Using streaming dataset: {args.use_streaming}")
    logger.info(f"Using flash attention: {args.use_flash_attention}")
    logger.info(f"Using gradient checkpointing: {args.use_gradient_checkpointing}")
    logger.info(f"Using RoPE scaling: {args.use_rope_scaling}")
    logger.info(f"RoPE scaling factor: {args.rope_scaling_factor}")
    logger.info(f"Using reasoning layer: {args.use_reasoning_layer}")
    logger.info(f"Number of reasoning layers: {config.model_config.num_reasoning_layers}")
    logger.info(f"Random seed: {args.seed}")
    logger.info(f"Output directory: {args.output_dir}")
    logger.info(f"Logging steps: {args.logging_steps}")
    logger.info(f"Save steps: {args.save_steps}")
    logger.info(f"Eval steps: {args.eval_steps}")
    logger.info(f"Profile steps: {args.profile_steps}")
    logger.info(f"Using Weights & Biases: {args.use_wandb and WANDB_AVAILABLE}")
    logger.info(f"Logging memory usage: {args.log_memory_usage}")

    # Rough parameter-count estimate: embeddings + transformer blocks
    # (+ optional reasoning blocks) + final norm + output projection.
    num_reasoning = config.model_config.num_reasoning_layers if config.model_config.use_reasoning_layer else 0
    param_count = (
        # Token embeddings
        config.model_config.vocab_size * config.model_config.hidden_size +
        # Transformer layers: attention, MLP, and layer norms
        config.model_config.num_hidden_layers * (
            4 * config.model_config.hidden_size * config.model_config.hidden_size +
            2 * config.model_config.hidden_size * config.model_config.intermediate_size +
            4 * config.model_config.hidden_size
        ) +
        # Optional reasoning layers
        num_reasoning * (
            4 * config.model_config.hidden_size * config.model_config.hidden_size +
            2 * config.model_config.hidden_size * config.model_config.reasoning_intermediate_size +
            4 * config.model_config.hidden_size
        ) +
        # Final layer norm and output projection
        config.model_config.hidden_size +
        config.model_config.hidden_size * config.model_config.vocab_size
    )

    logger.info(f"Approximate parameter count: {param_count / 1e9:.2f} billion parameters")

    # Back-of-the-envelope memory estimate
    bytes_per_param = 2 if config.dtype == jnp.bfloat16 else 4
    model_size_gb = param_count * bytes_per_param / 1e9
    optimizer_size_gb = model_size_gb * 2  # heuristic: two optimizer moments per parameter
    activation_size_gb = model_size_gb * 0.2
    total_memory_gb = model_size_gb + optimizer_size_gb + activation_size_gb

    logger.info("Estimated memory requirements:")
    logger.info(f"  Model parameters: {model_size_gb:.2f} GB")
    logger.info(f"  Optimizer states: {optimizer_size_gb:.2f} GB")
    logger.info(f"  Activations: {activation_size_gb:.2f} GB")
    logger.info(f"  Total: {total_memory_gb:.2f} GB")

    # TPU v4 provides 32 GB of HBM per chip
    tpu_memory_gb = 32 * jax.device_count()
    logger.info(f"Available TPU memory: {tpu_memory_gb:.2f} GB")
    if total_memory_gb > tpu_memory_gb * 0.9:
        logger.warning(f"Memory requirements ({total_memory_gb:.2f} GB) may exceed available TPU memory ({tpu_memory_gb:.2f} GB)")
        logger.warning("Consider enabling gradient checkpointing and using a smaller batch size")

    rng = jax.random.PRNGKey(config.seed)

    start_time = time.time()

    # Set up parallelism (data or tensor parallel)
    parallel = setup_parallelism(config)

    # Create the model
    model = create_model(config)
    logger.info(f"Model created in {time.time() - start_time:.2f} seconds")

    # Create the optimizer and training state
    optimizer = create_optimizer(config, config.max_steps)

    state_start_time = time.time()
    state = create_train_state(config, model, optimizer, rng)
    logger.info(f"Training state created in {time.time() - state_start_time:.2f} seconds")

    # Shard parameters across devices
    shard_start_time = time.time()
    state = state.replace(params=parallel.shard_params(state.params))
    logger.info(f"Parameters sharded in {time.time() - shard_start_time:.2f} seconds")

    # Optionally resume from a checkpoint
    if args.resume_from_checkpoint:
        checkpoint_start_time = time.time()
        state, step = load_checkpoint(args.resume_from_checkpoint, state)
        logger.info(f"Checkpoint loaded in {time.time() - checkpoint_start_time:.2f} seconds")

    # Load the tokenizer and datasets
    tokenizer_start_time = time.time()
    tokenizer = load_tokenizer(config)
    logger.info(f"Tokenizer loaded in {time.time() - tokenizer_start_time:.2f} seconds")

    dataset_start_time = time.time()
    train_dataset, eval_dataset = load_dataset(config, tokenizer)
    logger.info(f"Datasets loaded in {time.time() - dataset_start_time:.2f} seconds")

    dataloader_start_time = time.time()
    train_loader, eval_loader = create_data_loaders(
        config,
        train_dataset,
        eval_dataset,
        tokenizer
    )
    logger.info(f"Data loaders created in {time.time() - dataloader_start_time:.2f} seconds")

    # TensorBoard summary writer
    summary_writer = create_summary_writer(
        os.path.join(config.output_dir, "tensorboard")
    )

    # Build the trainer configuration
    trainer_config = TrainerConfig(
        model_config=config.model_config,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_steps=config.warmup_steps,
        max_steps=config.max_steps,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=config.max_grad_norm,
        adam_beta1=config.adam_beta1,
        adam_beta2=config.adam_beta2,
        adam_epsilon=config.adam_epsilon,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        eval_steps=config.eval_steps,
        output_dir=config.output_dir,
        seed=config.seed,
        dtype=config.dtype,

        # TPU-specific optimizations
        use_pjit=True,
        use_scan=True,
        use_remat=config.model_config.use_gradient_checkpointing,
        use_sharded_optim=True,
        profile_steps=args.profile_steps,
        async_checkpointing=True,
    )

    trainer = Trainer(
        config=trainer_config,
        model=model,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        state=state,
        parallel=parallel,
    )

    logger.info(f"Total initialization time: {time.time() - start_time:.2f} seconds")

    # Rough ETA, assuming roughly 5 minutes per step
    steps_per_day = 24 * 60 * 60 / (5 * 60)
    estimated_days = config.max_steps / steps_per_day
    logger.info(f"Estimated training time: {estimated_days:.2f} days for {config.max_steps} steps")

    # Run training
    try:
        train_start_time = time.time()
        trainer.train()
        train_duration = time.time() - train_start_time
        logger.info(f"Training completed in {train_duration / 3600:.2f} hours")
        logger.info(f"Average training speed: {config.max_steps / train_duration:.2f} steps/second")
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        raise


if __name__ == "__main__":
    main()
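
# Example invocation (illustrative sketch only: the script name and the data,
# tokenizer, and output paths below are placeholders, not files shipped with
# this repo; only --train_file and --tokenizer_file are required):
#
#   python train.py \
#       --model_size 7b \
#       --train_file data/train.jsonl \
#       --tokenizer_file tokenizer.model \
#       --batch_size 32 \
#       --gradient_accumulation_steps 8 \
#       --max_seq_length 131072 \
#       --output_dir output/7b-128k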