#!/usr/bin/env python3
"""
Main training script for LLM training on TPU v4-32.
Optimized for 128K token context length and 30-day training.
"""
import os
import sys
import time
import json
import argparse
import logging
import threading
import queue
from typing import Dict, Any, Optional, List, Tuple
import jax
import jax.numpy as jnp
import flax
import tensorflow as tf
import numpy as np
import sentencepiece as spm
from functools import partial
# Set up logging
logger = logging.getLogger(__name__)
# Try to import Weights & Biases
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
logger.warning("Weights & Biases not available. WandB logging will be disabled.")
WANDB_AVAILABLE = False
# Import local modules
from model.llm import LLM, LLMConfig
from data.tokenizer import SentencePieceTokenizer
from data.dataset import TextDataset, load_jsonl_dataset, StreamingDataset
from data.dataloader import TPUDataLoader
from training.trainer import Trainer, TrainingState, TrainingConfig as TrainerConfig
from training.optimizer import create_adamw_optimizer, create_lion_optimizer
from training.scheduler import create_linear_warmup_cosine_decay_schedule
from parallelism.data_parallel import DataParallel
from parallelism.tensor_parallel import TensorParallel
from config import create_config, Config
from utils.checkpoint import save_checkpoint, load_checkpoint
from utils.logging import setup_logger, log_metrics, create_summary_writer, log_metrics_to_tensorboard
from config import TrainingConfig, get_model_config
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``.
            Exposed so the parser can be driven programmatically in tests.

    Returns:
        argparse.Namespace with all training options.

    Note:
        The original flags used ``action="store_true", default=True``, which
        made them impossible to turn off from the command line. Each such
        flag now has a paired ``--no_<flag>`` to disable it; the defaults
        are unchanged.
    """
    parser = argparse.ArgumentParser(description="Train LLM on TPU v4-32")
    # Model parameters
    parser.add_argument("--model_size", type=str, default="7b", choices=["7b", "13b", "70b", "175b", "600b"],
                        help="Model size")
    # Training parameters
    parser.add_argument("--learning_rate", type=float, default=3e-4,
                        help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of steps to accumulate gradients")
    parser.add_argument("--max_steps", type=int, default=100000,
                        help="Maximum number of training steps")
    parser.add_argument("--warmup_steps", type=int, default=1000,
                        help="Number of warmup steps")
    # Dataset parameters
    parser.add_argument("--train_file", type=str, required=True,
                        help="Path to training file or HuggingFace dataset name")
    parser.add_argument("--eval_file", type=str, default="",
                        help="Path to evaluation file or HuggingFace dataset name")
    parser.add_argument("--tokenizer_file", type=str, required=True,
                        help="Path to tokenizer file")
    parser.add_argument("--max_seq_length", type=int, default=131072,
                        help="Maximum sequence length (default: 128K tokens)")
    parser.add_argument("--use_streaming", action="store_true", default=True,
                        help="Use streaming dataset for efficient training")
    parser.add_argument("--no_use_streaming", dest="use_streaming", action="store_false",
                        help="Disable streaming dataset")
    parser.add_argument("--streaming_buffer_size", type=int, default=10000,
                        help="Buffer size for streaming dataset")
    parser.add_argument("--text_column", type=str, default="text",
                        help="Name of text column in dataset")
    parser.add_argument("--preprocessing_num_workers", type=int, default=16,
                        help="Number of workers for dataset preprocessing")
    # Parallelism parameters
    parser.add_argument("--parallelism_type", type=str, default="data", choices=["data", "tensor"],
                        help="Type of parallelism")
    parser.add_argument("--tensor_parallel_size", type=int, default=8,
                        help="Number of tensor parallel devices")
    # Performance optimization parameters
    parser.add_argument("--use_flash_attention", action="store_true", default=True,
                        help="Use flash attention for efficiency")
    parser.add_argument("--no_use_flash_attention", dest="use_flash_attention", action="store_false",
                        help="Disable flash attention")
    parser.add_argument("--use_gradient_checkpointing", action="store_true", default=True,
                        help="Use gradient checkpointing to save memory")
    parser.add_argument("--no_use_gradient_checkpointing", dest="use_gradient_checkpointing", action="store_false",
                        help="Disable gradient checkpointing")
    # Long context support parameters
    parser.add_argument("--use_rope_scaling", action="store_true", default=True,
                        help="Use RoPE scaling for longer contexts")
    parser.add_argument("--no_use_rope_scaling", dest="use_rope_scaling", action="store_false",
                        help="Disable RoPE scaling")
    parser.add_argument("--rope_scaling_factor", type=float, default=0.5,
                        help="Scaling factor for RoPE frequencies")
    # Reasoning capabilities parameters
    parser.add_argument("--use_reasoning_layer", action="store_true", default=True,
                        help="Use additional reasoning layers")
    parser.add_argument("--no_use_reasoning_layer", dest="use_reasoning_layer", action="store_false",
                        help="Disable additional reasoning layers")
    parser.add_argument("--num_reasoning_layers", type=int, default=None,
                        help="Number of additional reasoning layers (overrides model config)")
    # Output parameters
    parser.add_argument("--output_dir", type=str, default="output",
                        help="Output directory")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Number of steps between logging")
    parser.add_argument("--save_steps", type=int, default=1000,
                        help="Number of steps between checkpoints")
    parser.add_argument("--eval_steps", type=int, default=1000,
                        help="Number of steps between evaluations")
    # Logging parameters
    parser.add_argument("--use_wandb", action="store_true", default=True,
                        help="Use Weights & Biases for logging")
    parser.add_argument("--no_use_wandb", dest="use_wandb", action="store_false",
                        help="Disable Weights & Biases logging")
    parser.add_argument("--wandb_project", type=str, default="llm-training",
                        help="Weights & Biases project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="Weights & Biases entity name")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Weights & Biases run name")
    parser.add_argument("--log_memory_usage", action="store_true", default=True,
                        help="Log memory usage during training")
    parser.add_argument("--no_log_memory_usage", dest="log_memory_usage", action="store_false",
                        help="Disable memory usage logging")
    parser.add_argument("--profile_steps", type=int, default=100,
                        help="Number of steps between profiling")
    # Miscellaneous parameters
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--resume_from_checkpoint", type=str, default="",
                        help="Path to checkpoint to resume from")
    return parser.parse_args(argv)
def create_config(args):
    """Build the TrainingConfig from parsed CLI arguments.

    NOTE(review): this function shadows the ``create_config`` imported from
    ``config`` at the top of the file — confirm the imported one is still
    needed elsewhere.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        A fully-populated ``TrainingConfig``, including the streaming /
        preprocessing attributes that ``load_dataset`` reads.
    """
    # Get model configuration for the requested size
    model_config = get_model_config(args.model_size)
    # Override model configuration with command line arguments if provided
    if args.num_reasoning_layers is not None:
        model_config.num_reasoning_layers = args.num_reasoning_layers
    # Update model configuration with command line arguments
    model_config.use_flash_attention = args.use_flash_attention
    model_config.use_gradient_checkpointing = args.use_gradient_checkpointing
    model_config.use_rope_scaling = args.use_rope_scaling
    model_config.rope_scaling_factor = args.rope_scaling_factor
    model_config.use_reasoning_layer = args.use_reasoning_layer
    # Create training configuration
    config = TrainingConfig(
        output_dir=args.output_dir,
        model_config=model_config,
        # Training parameters
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,
        # Dataset parameters
        train_file=args.train_file,
        eval_file=args.eval_file,
        tokenizer_file=args.tokenizer_file,
        max_seq_length=args.max_seq_length,
        # Parallelism parameters
        parallelism_type=args.parallelism_type,
        tensor_parallel_size=args.tensor_parallel_size,
        # Performance optimization parameters
        use_flash_attention=args.use_flash_attention,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        # Long context support parameters
        use_rope_scaling=args.use_rope_scaling,
        rope_scaling_factor=args.rope_scaling_factor,
        # Reasoning capabilities parameters
        use_reasoning_layer=args.use_reasoning_layer,
        num_reasoning_layers=args.num_reasoning_layers if args.num_reasoning_layers is not None else model_config.num_reasoning_layers,
        reasoning_intermediate_size=model_config.reasoning_intermediate_size,
        # Logging parameters
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        # Miscellaneous parameters
        seed=args.seed
    )
    # Bug fix: load_dataset() reads these attributes off the config, but they
    # were never copied over from the CLI arguments. Set them explicitly here
    # rather than via the constructor in case TrainingConfig predates them.
    config.use_streaming = args.use_streaming
    config.streaming_buffer_size = args.streaming_buffer_size
    config.text_column = args.text_column
    config.preprocessing_num_workers = args.preprocessing_num_workers
    return config
def setup_parallelism(config):
    """Instantiate the parallelism strategy selected in the config.

    Raises:
        ValueError: if ``config.parallelism_type`` is not "data" or "tensor".
    """
    strategy = config.parallelism_type
    if strategy == "data":
        return DataParallel()
    if strategy == "tensor":
        return TensorParallel(num_tp=config.tensor_parallel_size)
    raise ValueError(f"Parallelism type {config.parallelism_type} not supported")
def create_model(config):
    """Instantiate the LLM from the model sub-config."""
    model_cfg = config.model_config
    return LLM(model_cfg)
def create_optimizer(config, num_train_steps):
    """Build the optimizer with a linear-warmup + cosine-decay LR schedule.

    Args:
        config: Training configuration (supplies LR, betas, weight decay
            and the optimizer name).
        num_train_steps: Total number of optimization steps; decay runs
            over the steps remaining after warmup.

    Raises:
        ValueError: if ``config.optimizer`` is neither "adamw" nor "lion".
    """
    schedule = create_linear_warmup_cosine_decay_schedule(
        learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps,
        decay_steps=num_train_steps - config.warmup_steps,
        final_learning_rate_factor=0.1,
    )
    name = config.optimizer
    if name == "adamw":
        return create_adamw_optimizer(
            learning_rate=schedule,
            weight_decay=config.weight_decay,
            b1=config.adam_beta1,
            b2=config.adam_beta2,
            eps=config.adam_epsilon,
        )
    if name == "lion":
        # Lion reuses the Adam beta settings from the config.
        return create_lion_optimizer(
            learning_rate=schedule,
            weight_decay=config.weight_decay,
            b1=config.adam_beta1,
            b2=config.adam_beta2,
        )
    raise ValueError(f"Optimizer {config.optimizer} not supported")
def create_train_state(config, model, optimizer, rng):
    """Initialize model parameters and wrap everything in a TrainingState.

    Splits ``rng`` into an init key and a dropout key, initializes the model
    on a minimal (1, 1) int32 dummy input, and returns the training state.
    """
    params_rng, dropout_rng = jax.random.split(rng)
    # Smallest possible token batch — just enough to trace parameter shapes.
    dummy_tokens = jnp.ones((1, 1), dtype=jnp.int32)
    params = model.init(params_rng, dummy_tokens)
    return TrainingState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        dropout_rng=dropout_rng,
        loss_scale=1.0,
    )
def load_tokenizer(config):
    """Construct the SentencePiece tokenizer from the configured model file."""
    tokenizer_path = config.tokenizer_file
    return SentencePieceTokenizer(tokenizer_path)
def load_dataset(config, tokenizer):
    """Load train (and optional eval) datasets, streaming when configured.

    The train/eval loading paths were duplicated in the original; they are
    now shared via ``_load_split``.

    Args:
        config: Training configuration (paths, seq length, streaming options).
        tokenizer: Tokenizer passed through to the dataset constructors.

    Returns:
        Tuple ``(train_dataset, eval_dataset)``; ``eval_dataset`` is ``None``
        when no ``eval_file`` is configured.
    """
    def _load_split(path, split_name, streaming):
        # `streaming` controls whether the StreamingDataset actually streams;
        # evaluation passes False so iteration order is reproducible.
        if config.use_streaming:
            logger.info(f"Loading streaming {split_name} dataset from {path}")
            dataset = StreamingDataset(
                tokenizer=tokenizer,
                dataset_path=path,
                max_seq_length=config.max_seq_length,
                streaming=streaming,
                buffer_size=config.streaming_buffer_size,
                seed=config.seed,
                text_column=config.text_column,
                preprocessing_num_workers=config.preprocessing_num_workers
            )
            logger.info(f"Streaming {split_name} dataset loaded successfully")
        else:
            logger.info(f"Loading standard {split_name} dataset from {path}")
            dataset = load_jsonl_dataset(
                file_path=path,
                tokenizer=tokenizer,
                max_length=config.max_seq_length
            )
            logger.info(f"{split_name.capitalize()} dataset loaded with {len(dataset)} examples")
        return dataset

    train_dataset = _load_split(config.train_file, "training", streaming=True)
    eval_dataset = None
    if config.eval_file:
        eval_dataset = _load_split(config.eval_file, "evaluation", streaming=False)
    return train_dataset, eval_dataset
def create_data_loaders(config, train_dataset, eval_dataset, tokenizer):
    """Wrap the datasets in TPU-aware data loaders.

    Training shuffles and drops the last partial batch; evaluation keeps a
    fixed order and retains all examples. Returns ``(train_loader,
    eval_loader)`` with ``eval_loader`` as ``None`` when there is no eval set.
    """
    pad_id = tokenizer.pad_token_id
    train_loader = TPUDataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True,
        pad_token_id=pad_id,
    )
    if eval_dataset is None:
        return train_loader, None
    eval_loader = TPUDataLoader(
        dataset=eval_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False,
        pad_token_id=pad_id,
    )
    return train_loader, eval_loader
def main():
    """Main function optimized for TPU v4-32.

    End-to-end training driver: parses CLI arguments, builds the config,
    model, optimizer, datasets and trainer, then runs the training loop.

    Fixes versus the original:
      * The parameter-count / memory-estimate section was duplicated
        verbatim and logged twice; it now runs once.
      * The reasoning-layer contribution used ``(bool and int)`` arithmetic;
        it now uses an explicit conditional.
      * ``trainer_config`` hardcoded ``profile_steps=100`` even though
        ``--profile_steps`` is a CLI option; it now honors the argument.
    """
    # Parse arguments
    args = parse_args()
    # Print TPU configuration information
    print("TPU Configuration:")
    print(f"Number of TPU devices: {jax.device_count()}")
    print(f"TPU devices: {jax.devices()}")
    print(f"JAX process index: {jax.process_index()}")
    print(f"JAX process count: {jax.process_count()}")
    print(f"JAX local devices: {jax.local_devices()}")
    print(f"JAX local device count: {jax.local_device_count()}")
    # Create configuration and output directory
    config = create_config(args)
    os.makedirs(config.output_dir, exist_ok=True)
    # Set up logging. This intentionally shadows the module-level logger so
    # everything below is written to the run-specific log file.
    logger = setup_logger(
        name="tpu_train",
        log_file=os.path.join(config.output_dir, "train.log")
    )
    logger.info(f"Configuration: {config}")
    # Initialize Weights & Biases if enabled
    if args.use_wandb and WANDB_AVAILABLE:
        logger.info("Initializing Weights & Biases")
        wandb_run_name = args.wandb_run_name or f"{args.model_size}-{time.strftime('%Y%m%d-%H%M%S')}"
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=wandb_run_name,
            config={
                "model_size": args.model_size,
                "learning_rate": args.learning_rate,
                "batch_size": args.batch_size,
                "gradient_accumulation_steps": args.gradient_accumulation_steps,
                "max_steps": args.max_steps,
                "warmup_steps": args.warmup_steps,
                "max_seq_length": args.max_seq_length,
                "parallelism_type": args.parallelism_type,
                "tensor_parallel_size": args.tensor_parallel_size,
                "use_flash_attention": args.use_flash_attention,
                "use_gradient_checkpointing": args.use_gradient_checkpointing,
                "use_rope_scaling": args.use_rope_scaling,
                "rope_scaling_factor": args.rope_scaling_factor,
                "use_reasoning_layer": args.use_reasoning_layer,
                "num_reasoning_layers": args.num_reasoning_layers,
                "use_streaming": args.use_streaming,
                "streaming_buffer_size": args.streaming_buffer_size,
                "text_column": args.text_column,
                "preprocessing_num_workers": args.preprocessing_num_workers,
                "seed": args.seed,
            }
        )
        logger.info(f"Weights & Biases initialized with run name: {wandb_run_name}")
    elif args.use_wandb and not WANDB_AVAILABLE:
        logger.warning("Weights & Biases not available. Install wandb package to enable logging.")
    else:
        logger.info("Weights & Biases logging disabled.")
    # Log hardware and run information
    logger.info(f"Training on TPU v4-32 with {jax.device_count()} devices")
    logger.info(f"Model size: {args.model_size} ({config.model_config.hidden_size} hidden size, "
                f"{config.model_config.num_hidden_layers} layers)")
    logger.info(f"Max sequence length: {args.max_seq_length} tokens")
    logger.info(f"Batch size: {args.batch_size} per device, {args.batch_size * jax.device_count()} global")
    logger.info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
    logger.info(f"Effective batch size: {args.batch_size * jax.device_count() * args.gradient_accumulation_steps}")
    logger.info(f"Learning rate: {args.learning_rate}")
    logger.info(f"Warmup steps: {args.warmup_steps}")
    logger.info(f"Max steps: {args.max_steps}")
    logger.info(f"Parallelism type: {args.parallelism_type}")
    logger.info(f"Tensor parallel size: {args.tensor_parallel_size}")
    logger.info(f"Using streaming dataset: {args.use_streaming}")
    logger.info(f"Using flash attention: {args.use_flash_attention}")
    logger.info(f"Using gradient checkpointing: {args.use_gradient_checkpointing}")
    logger.info(f"Using RoPE scaling: {args.use_rope_scaling}")
    logger.info(f"RoPE scaling factor: {args.rope_scaling_factor}")
    logger.info(f"Using reasoning layer: {args.use_reasoning_layer}")
    logger.info(f"Number of reasoning layers: {config.model_config.num_reasoning_layers}")
    logger.info(f"Random seed: {args.seed}")
    logger.info(f"Output directory: {args.output_dir}")
    logger.info(f"Logging steps: {args.logging_steps}")
    logger.info(f"Save steps: {args.save_steps}")
    logger.info(f"Eval steps: {args.eval_steps}")
    logger.info(f"Profile steps: {args.profile_steps}")
    logger.info(f"Using Weights & Biases: {args.use_wandb and WANDB_AVAILABLE}")
    logger.info(f"Logging memory usage: {args.log_memory_usage}")
    # Calculate approximate model size (once — the original computed and
    # logged this entire section twice back-to-back).
    mc = config.model_config
    # Reasoning layers contribute only when the feature is enabled.
    num_reasoning = mc.num_reasoning_layers if mc.use_reasoning_layer else 0
    param_count = (
        # Embedding parameters
        mc.vocab_size * mc.hidden_size +
        # Transformer layers
        mc.num_hidden_layers * (
            # Self-attention (Q, K, V, O projections)
            4 * mc.hidden_size * mc.hidden_size +
            # Feed-forward
            2 * mc.hidden_size * mc.intermediate_size +
            # Layer normalization
            4 * mc.hidden_size
        ) +
        # Reasoning layers if enabled
        num_reasoning * (
            # Self-attention
            4 * mc.hidden_size * mc.hidden_size +
            # Feed-forward with larger hidden dimension
            2 * mc.hidden_size * mc.reasoning_intermediate_size +
            # Layer normalization
            4 * mc.hidden_size
        ) +
        # Final layer normalization
        mc.hidden_size +
        # Output projection
        mc.hidden_size * mc.vocab_size
    )
    logger.info(f"Approximate parameter count: {param_count / 1e9:.2f} billion parameters")
    # Calculate memory requirements
    bytes_per_param = 2 if config.dtype == jnp.bfloat16 else 4  # bfloat16 or float32
    model_size_gb = param_count * bytes_per_param / 1e9
    optimizer_size_gb = model_size_gb * 2  # Adam uses 2x model size for optimizer states
    activation_size_gb = model_size_gb * 0.2  # Rough estimate for activations
    total_memory_gb = model_size_gb + optimizer_size_gb + activation_size_gb
    logger.info(f"Estimated memory requirements:")
    logger.info(f"  Model parameters: {model_size_gb:.2f} GB")
    logger.info(f"  Optimizer states: {optimizer_size_gb:.2f} GB")
    logger.info(f"  Activations: {activation_size_gb:.2f} GB")
    logger.info(f"  Total: {total_memory_gb:.2f} GB")
    # Check if memory requirements exceed available TPU memory
    tpu_memory_gb = 32 * jax.device_count()  # Each TPU v4 has 32GB HBM
    logger.info(f"Available TPU memory: {tpu_memory_gb:.2f} GB")
    if total_memory_gb > tpu_memory_gb * 0.9:  # Leave 10% margin
        logger.warning(f"Memory requirements ({total_memory_gb:.2f} GB) may exceed available TPU memory ({tpu_memory_gb:.2f} GB)")
        logger.warning("Consider enabling gradient checkpointing and using a smaller batch size")
    # Set random seed
    rng = jax.random.PRNGKey(config.seed)
    # Measure initialization time
    start_time = time.time()
    # Set up parallelism with optimized configuration for TPU v4-32
    parallel = setup_parallelism(config)
    # Create model
    model = create_model(config)
    logger.info(f"Model created in {time.time() - start_time:.2f} seconds")
    # Create optimizer with memory-efficient configuration
    optimizer = create_optimizer(config, config.max_steps)
    # Create training state
    state_start_time = time.time()
    state = create_train_state(config, model, optimizer, rng)
    logger.info(f"Training state created in {time.time() - state_start_time:.2f} seconds")
    # Shard parameters with optimized sharding strategy
    shard_start_time = time.time()
    state = state.replace(params=parallel.shard_params(state.params))
    logger.info(f"Parameters sharded in {time.time() - shard_start_time:.2f} seconds")
    # Load checkpoint if requested
    if args.resume_from_checkpoint:
        checkpoint_start_time = time.time()
        # NOTE(review): `step` is never fed back into the trainer — confirm
        # whether resuming should also restore the step counter.
        state, step = load_checkpoint(args.resume_from_checkpoint, state)
        logger.info(f"Checkpoint loaded in {time.time() - checkpoint_start_time:.2f} seconds")
    # Load tokenizer
    tokenizer_start_time = time.time()
    tokenizer = load_tokenizer(config)
    logger.info(f"Tokenizer loaded in {time.time() - tokenizer_start_time:.2f} seconds")
    # Load dataset with optimized loading
    dataset_start_time = time.time()
    train_dataset, eval_dataset = load_dataset(config, tokenizer)
    logger.info(f"Datasets loaded in {time.time() - dataset_start_time:.2f} seconds")
    # Create data loaders with optimized configuration for TPU v4-32
    dataloader_start_time = time.time()
    train_loader, eval_loader = create_data_loaders(
        config,
        train_dataset,
        eval_dataset,
        tokenizer
    )
    logger.info(f"Data loaders created in {time.time() - dataloader_start_time:.2f} seconds")
    # Create TensorBoard summary writer
    summary_writer = create_summary_writer(
        os.path.join(config.output_dir, "tensorboard")
    )
    # Create trainer configuration with optimized settings for TPU v4-32
    trainer_config = TrainerConfig(
        model_config=config.model_config,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_steps=config.warmup_steps,
        max_steps=config.max_steps,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=config.max_grad_norm,
        adam_beta1=config.adam_beta1,
        adam_beta2=config.adam_beta2,
        adam_epsilon=config.adam_epsilon,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        eval_steps=config.eval_steps,
        output_dir=config.output_dir,
        seed=config.seed,
        dtype=config.dtype,
        # Additional optimized settings for TPU v4-32
        use_pjit=True,  # Use pjit for better performance
        use_scan=True,  # Use scan for layer iteration
        use_remat=config.model_config.use_gradient_checkpointing,  # Rematerialization for memory efficiency
        use_sharded_optim=True,  # Use sharded optimizer states
        profile_steps=args.profile_steps,  # Honor --profile_steps (was hardcoded to 100)
        async_checkpointing=True,  # Use async checkpointing for better performance
    )
    # Create trainer
    trainer = Trainer(
        config=trainer_config,
        model=model,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        state=state,
        parallel=parallel,  # Pass parallelism object for optimized training
    )
    # Log total initialization time
    logger.info(f"Total initialization time: {time.time() - start_time:.2f} seconds")
    # Calculate estimated training time
    steps_per_day = 24 * 60 * 60 / (5 * 60)  # Assuming 5 minutes per 100 steps (rough estimate)
    estimated_days = config.max_steps / steps_per_day
    logger.info(f"Estimated training time: {estimated_days:.2f} days for {config.max_steps} steps")
    # Train model with performance monitoring
    try:
        train_start_time = time.time()
        trainer.train()
        train_duration = time.time() - train_start_time
        logger.info(f"Training completed in {train_duration / 3600:.2f} hours")
        logger.info(f"Average training speed: {config.max_steps / train_duration:.2f} steps/second")
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        raise
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()