File size: 28,562 Bytes
f24563f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
#!/usr/bin/env python3
"""
Main training script for LLM training on TPU v4-32.
Optimized for 128K token context length and 30-day training.
"""

import os
import sys
import time
import json
import argparse
import logging
import threading
import queue
from typing import Dict, Any, Optional, List, Tuple
import jax
import jax.numpy as jnp
import flax
import tensorflow as tf
import numpy as np
import sentencepiece as spm
from functools import partial

# Module-level logger used by the helper functions defined in this file.
logger = logging.getLogger(__name__)

# Weights & Biases is an optional dependency: remember whether it could be
# imported so that W&B logging can be skipped gracefully when it is missing.
try:
    import wandb
except ImportError:
    logger.warning("Weights & Biases not available. WandB logging will be disabled.")
    WANDB_AVAILABLE = False
else:
    WANDB_AVAILABLE = True

# Import local modules
from model.llm import LLM, LLMConfig
from data.tokenizer import SentencePieceTokenizer
from data.dataset import TextDataset, load_jsonl_dataset, StreamingDataset
from data.dataloader import TPUDataLoader
from training.trainer import Trainer, TrainingState, TrainingConfig as TrainerConfig
from training.optimizer import create_adamw_optimizer, create_lion_optimizer
from training.scheduler import create_linear_warmup_cosine_decay_schedule
from parallelism.data_parallel import DataParallel
from parallelism.tensor_parallel import TensorParallel
from config import create_config, Config
from utils.checkpoint import save_checkpoint, load_checkpoint
from utils.logging import setup_logger, log_metrics, create_summary_writer, log_metrics_to_tensorboard
from config import TrainingConfig, get_model_config


def _add_bool_flag(parser, name, help_text, default=True):
    """Register a boolean switch as a ``--<name>`` / ``--no_<name>`` pair.

    ``--<name>`` keeps the original ``store_true`` spelling so existing command
    lines keep working; ``--no_<name>`` stores ``False`` into the same
    destination. Without the negative form, a flag that defaults to ``True``
    could never be turned off from the command line.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend.
        name: Destination / option name without leading dashes.
        help_text: Help string for the positive option.
        default: Value used when neither option is given.
    """
    parser.add_argument(f"--{name}", dest=name, action="store_true", help=help_text)
    parser.add_argument(f"--no_{name}", dest=name, action="store_false",
                        help=f"Disable --{name}")
    # Set the shared default explicitly so it does not depend on which of the
    # two actions argparse happens to consult first.
    parser.set_defaults(**{name: default})


def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train LLM on TPU v4-32")

    # Model parameters
    parser.add_argument("--model_size", type=str, default="7b", choices=["7b", "13b", "70b", "175b", "600b"],
                        help="Model size")

    # Training parameters
    parser.add_argument("--learning_rate", type=float, default=3e-4,
                        help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of steps to accumulate gradients")
    parser.add_argument("--max_steps", type=int, default=100000,
                        help="Maximum number of training steps")
    parser.add_argument("--warmup_steps", type=int, default=1000,
                        help="Number of warmup steps")

    # Dataset parameters
    parser.add_argument("--train_file", type=str, required=True,
                        help="Path to training file or HuggingFace dataset name")
    parser.add_argument("--eval_file", type=str, default="",
                        help="Path to evaluation file or HuggingFace dataset name")
    parser.add_argument("--tokenizer_file", type=str, required=True,
                        help="Path to tokenizer file")
    parser.add_argument("--max_seq_length", type=int, default=131072,
                        help="Maximum sequence length (default: 128K tokens)")
    _add_bool_flag(parser, "use_streaming",
                   "Use streaming dataset for efficient training")
    parser.add_argument("--streaming_buffer_size", type=int, default=10000,
                        help="Buffer size for streaming dataset")
    parser.add_argument("--text_column", type=str, default="text",
                        help="Name of text column in dataset")
    parser.add_argument("--preprocessing_num_workers", type=int, default=16,
                        help="Number of workers for dataset preprocessing")

    # Parallelism parameters
    parser.add_argument("--parallelism_type", type=str, default="data", choices=["data", "tensor"],
                        help="Type of parallelism")
    parser.add_argument("--tensor_parallel_size", type=int, default=8,
                        help="Number of tensor parallel devices")

    # Performance optimization parameters
    _add_bool_flag(parser, "use_flash_attention",
                   "Use flash attention for efficiency")
    _add_bool_flag(parser, "use_gradient_checkpointing",
                   "Use gradient checkpointing to save memory")

    # Long context support parameters
    _add_bool_flag(parser, "use_rope_scaling",
                   "Use RoPE scaling for longer contexts")
    parser.add_argument("--rope_scaling_factor", type=float, default=0.5,
                        help="Scaling factor for RoPE frequencies")

    # Reasoning capabilities parameters
    _add_bool_flag(parser, "use_reasoning_layer",
                   "Use additional reasoning layers")
    parser.add_argument("--num_reasoning_layers", type=int, default=None,
                        help="Number of additional reasoning layers (overrides model config)")

    # Output parameters
    parser.add_argument("--output_dir", type=str, default="output",
                        help="Output directory")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Number of steps between logging")
    parser.add_argument("--save_steps", type=int, default=1000,
                        help="Number of steps between checkpoints")
    parser.add_argument("--eval_steps", type=int, default=1000,
                        help="Number of steps between evaluations")

    # Logging parameters
    _add_bool_flag(parser, "use_wandb",
                   "Use Weights & Biases for logging")
    parser.add_argument("--wandb_project", type=str, default="llm-training",
                        help="Weights & Biases project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="Weights & Biases entity name")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Weights & Biases run name")
    _add_bool_flag(parser, "log_memory_usage",
                   "Log memory usage during training")
    parser.add_argument("--profile_steps", type=int, default=100,
                        help="Number of steps between profiling")

    # Miscellaneous parameters
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--resume_from_checkpoint", type=str, default="",
                        help="Path to checkpoint to resume from")

    return parser.parse_args(argv)


def create_config(args):
    """Create training configuration from parsed command-line arguments.

    Loads the preset model configuration for ``args.model_size``, applies the
    CLI overrides to it, and wraps everything in a ``TrainingConfig``.

    Note: this definition shadows the ``create_config`` imported from
    ``config`` at the top of the file; within this module, this one is used.

    Args:
        args: Namespace returned by ``parse_args``.

    Returns:
        The populated ``TrainingConfig``. It also carries the dataset-loading
        options (``use_streaming``, ``streaming_buffer_size``, ``text_column``,
        ``preprocessing_num_workers``) that ``load_dataset`` reads.
    """
    # Get model configuration
    model_config = get_model_config(args.model_size)

    # Override model configuration with command line arguments if provided
    if args.num_reasoning_layers is not None:
        model_config.num_reasoning_layers = args.num_reasoning_layers

    # Update model configuration with command line arguments
    model_config.use_flash_attention = args.use_flash_attention
    model_config.use_gradient_checkpointing = args.use_gradient_checkpointing
    model_config.use_rope_scaling = args.use_rope_scaling
    model_config.rope_scaling_factor = args.rope_scaling_factor
    model_config.use_reasoning_layer = args.use_reasoning_layer

    # Create training configuration
    config = TrainingConfig(
        output_dir=args.output_dir,
        model_config=model_config,

        # Training parameters
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,

        # Dataset parameters
        train_file=args.train_file,
        eval_file=args.eval_file,
        tokenizer_file=args.tokenizer_file,
        max_seq_length=args.max_seq_length,

        # Parallelism parameters
        parallelism_type=args.parallelism_type,
        tensor_parallel_size=args.tensor_parallel_size,

        # Performance optimization parameters
        use_flash_attention=args.use_flash_attention,
        use_gradient_checkpointing=args.use_gradient_checkpointing,

        # Long context support parameters
        use_rope_scaling=args.use_rope_scaling,
        rope_scaling_factor=args.rope_scaling_factor,

        # Reasoning capabilities parameters
        use_reasoning_layer=args.use_reasoning_layer,
        # model_config already reflects any --num_reasoning_layers override.
        num_reasoning_layers=model_config.num_reasoning_layers,
        reasoning_intermediate_size=model_config.reasoning_intermediate_size,

        # Logging parameters
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,

        # Miscellaneous parameters
        seed=args.seed
    )

    # load_dataset() reads these options from the config. Previously they were
    # parsed on the CLI but never forwarded, so the user's values were ignored.
    # They are assigned as attributes rather than passed to the constructor
    # because it is not visible here whether TrainingConfig declares them.
    # NOTE(review): this requires TrainingConfig instances to be mutable
    # (e.g. not a frozen dataclass) — confirm in config.py.
    config.use_streaming = args.use_streaming
    config.streaming_buffer_size = args.streaming_buffer_size
    config.text_column = args.text_column
    config.preprocessing_num_workers = args.preprocessing_num_workers

    return config


def setup_parallelism(config):
    """Return the parallelism strategy selected by ``config.parallelism_type``.

    ``"data"`` yields a ``DataParallel`` instance. ``"tensor"`` yields a
    ``TensorParallel`` instance sharded across ``config.tensor_parallel_size``
    devices. Any other value raises ``ValueError``.
    """
    kind = config.parallelism_type
    if kind == "tensor":
        return TensorParallel(num_tp=config.tensor_parallel_size)
    if kind == "data":
        return DataParallel()
    raise ValueError(f"Parallelism type {kind} not supported")


def create_model(config):
    """Instantiate the ``LLM`` described by ``config.model_config``."""
    model_cfg = config.model_config
    model = LLM(model_cfg)
    return model


def create_optimizer(config, num_train_steps):
    """Create the optimizer with a linear-warmup / cosine-decay LR schedule.

    Args:
        config: Training config. Reads ``optimizer`` (``"adamw"`` or
            ``"lion"``), ``learning_rate``, ``warmup_steps``, ``weight_decay``,
            ``adam_beta1``, ``adam_beta2`` and (AdamW only) ``adam_epsilon``.
        num_train_steps: Total number of training steps the schedule spans.

    Returns:
        The optimizer produced by ``create_adamw_optimizer`` or
        ``create_lion_optimizer``.

    Raises:
        ValueError: If ``config.optimizer`` is not a supported name. The check
            happens before any schedule is built.
    """
    if config.optimizer not in ("adamw", "lion"):
        raise ValueError(f"Optimizer {config.optimizer} not supported")

    # The cosine phase covers whatever remains after warmup. Clamp it to at
    # least one step so that a warmup as long as (or longer than) the whole
    # run does not pass a zero or negative decay length to the scheduler.
    decay_steps = max(1, num_train_steps - config.warmup_steps)

    # Create learning rate schedule
    lr_schedule = create_linear_warmup_cosine_decay_schedule(
        learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps,
        decay_steps=decay_steps,
        final_learning_rate_factor=0.1
    )

    # Create optimizer
    if config.optimizer == "adamw":
        return create_adamw_optimizer(
            learning_rate=lr_schedule,
            weight_decay=config.weight_decay,
            b1=config.adam_beta1,
            b2=config.adam_beta2,
            eps=config.adam_epsilon
        )

    # Lion reuses the Adam beta settings; no epsilon is passed to it.
    return create_lion_optimizer(
        learning_rate=lr_schedule,
        weight_decay=config.weight_decay,
        b1=config.adam_beta1,
        b2=config.adam_beta2
    )


def create_train_state(config, model, optimizer, rng):
    """Initialise the model's variables and wrap them in a ``TrainingState``.

    Args:
        config: Not read here; kept for signature compatibility.
        model: Module exposing ``init`` and ``apply``.
        optimizer: Transformation stored as the state's ``tx``.
        rng: JAX PRNG key. It is split into an init key and a dropout key.

    Returns:
        ``TrainingState.create(...)`` with ``loss_scale=1.0``.
    """
    init_key, dropout_key = jax.random.split(rng)

    # A (1, 1) int32 token batch is only used to trace parameter shapes.
    sample_tokens = jnp.ones((1, 1), dtype=jnp.int32)
    # NOTE(review): Flax's Module.init returns the full variables dict
    # (e.g. {"params": ...}); confirm TrainingState expects that rather
    # than variables["params"].
    variables = model.init(init_key, sample_tokens)

    return TrainingState.create(
        apply_fn=model.apply,
        params=variables,
        tx=optimizer,
        dropout_rng=dropout_key,
        loss_scale=1.0,
    )


def load_tokenizer(config):
    """Build a ``SentencePieceTokenizer`` from ``config.tokenizer_file``."""
    model_path = config.tokenizer_file
    tokenizer = SentencePieceTokenizer(model_path)
    return tokenizer


def load_dataset(config, tokenizer):
    """Load the training and (optional) evaluation datasets.

    When ``config.use_streaming`` is set, both splits are built with
    ``StreamingDataset``. The training split uses ``streaming=True``; the
    evaluation split uses ``streaming=False`` for reproducibility. Otherwise
    both splits are read with ``load_jsonl_dataset``.

    Args:
        config: Reads ``use_streaming``, ``train_file``, ``eval_file``,
            ``max_seq_length``, ``streaming_buffer_size``, ``seed``,
            ``text_column`` and ``preprocessing_num_workers``.
        tokenizer: Tokenizer handed to the dataset constructors.

    Returns:
        ``(train_dataset, eval_dataset)``. ``eval_dataset`` is ``None`` when
        ``config.eval_file`` is empty.
    """

    def _streaming_split(path, streaming):
        # Shared StreamingDataset construction for both splits.
        return StreamingDataset(
            tokenizer=tokenizer,
            dataset_path=path,
            max_seq_length=config.max_seq_length,
            streaming=streaming,
            buffer_size=config.streaming_buffer_size,
            seed=config.seed,
            text_column=config.text_column,
            preprocessing_num_workers=config.preprocessing_num_workers,
        )

    def _jsonl_split(path):
        # Shared in-memory JSONL loading for both splits.
        return load_jsonl_dataset(
            file_path=path,
            tokenizer=tokenizer,
            max_length=config.max_seq_length,
        )

    # ---- training split ----
    if config.use_streaming:
        logger.info(f"Loading streaming dataset from {config.train_file}")
        train_dataset = _streaming_split(config.train_file, True)
        logger.info("Streaming dataset loaded successfully")
    else:
        logger.info(f"Loading standard dataset from {config.train_file}")
        train_dataset = _jsonl_split(config.train_file)
        logger.info(f"Dataset loaded with {len(train_dataset)} examples")

    # ---- evaluation split (optional) ----
    if not config.eval_file:
        return train_dataset, None

    if config.use_streaming:
        logger.info(f"Loading streaming evaluation dataset from {config.eval_file}")
        # Non-streaming mode for evaluation keeps it reproducible.
        eval_dataset = _streaming_split(config.eval_file, False)
        logger.info("Streaming evaluation dataset loaded successfully")
    else:
        logger.info(f"Loading standard evaluation dataset from {config.eval_file}")
        eval_dataset = _jsonl_split(config.eval_file)
        logger.info(f"Evaluation dataset loaded with {len(eval_dataset)} examples")

    return train_dataset, eval_dataset


def create_data_loaders(config, train_dataset, eval_dataset, tokenizer):
    """Wrap the datasets in ``TPUDataLoader`` instances.

    The training loader shuffles and drops the last incomplete batch. The
    evaluation loader does neither. Both pad with ``tokenizer.pad_token_id``
    and use ``config.batch_size``.

    Returns:
        ``(train_loader, eval_loader)``. ``eval_loader`` is ``None`` when
        ``eval_dataset`` is ``None``.
    """

    def _make_loader(dataset, is_training):
        # Training vs. evaluation differ only in shuffling and drop_last.
        return TPUDataLoader(
            dataset=dataset,
            batch_size=config.batch_size,
            shuffle=is_training,
            drop_last=is_training,
            pad_token_id=tokenizer.pad_token_id,
        )

    train_loader = _make_loader(train_dataset, True)
    eval_loader = None if eval_dataset is None else _make_loader(eval_dataset, False)
    return train_loader, eval_loader


def _init_wandb(args, logger):
    """Start a Weights & Biases run when requested and importable.

    Args:
        args: Namespace returned by ``parse_args``.
        logger: Logger used for status messages.

    Returns:
        True if ``wandb.init`` was called, in which case the caller should
        call ``wandb.finish`` when training ends. False otherwise.
    """
    if not args.use_wandb:
        logger.info("Weights & Biases logging disabled.")
        return False
    if not WANDB_AVAILABLE:
        logger.warning("Weights & Biases not available. Install wandb package to enable logging.")
        return False

    logger.info("Initializing Weights & Biases")
    wandb_run_name = args.wandb_run_name or f"{args.model_size}-{time.strftime('%Y%m%d-%H%M%S')}"
    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=wandb_run_name,
        config={
            "model_size": args.model_size,
            "learning_rate": args.learning_rate,
            "batch_size": args.batch_size,
            "gradient_accumulation_steps": args.gradient_accumulation_steps,
            "max_steps": args.max_steps,
            "warmup_steps": args.warmup_steps,
            "max_seq_length": args.max_seq_length,
            "parallelism_type": args.parallelism_type,
            "tensor_parallel_size": args.tensor_parallel_size,
            "use_flash_attention": args.use_flash_attention,
            "use_gradient_checkpointing": args.use_gradient_checkpointing,
            "use_rope_scaling": args.use_rope_scaling,
            "rope_scaling_factor": args.rope_scaling_factor,
            "use_reasoning_layer": args.use_reasoning_layer,
            "num_reasoning_layers": args.num_reasoning_layers,
            "use_streaming": args.use_streaming,
            "streaming_buffer_size": args.streaming_buffer_size,
            "text_column": args.text_column,
            "preprocessing_num_workers": args.preprocessing_num_workers,
            "seed": args.seed,
        }
    )
    logger.info(f"Weights & Biases initialized with run name: {wandb_run_name}")
    return True


def _estimate_param_count(model_config):
    """Return a rough parameter count for ``model_config``.

    Each layer contributes 4*h*h (attention projections), 2*h*ffn
    (feed-forward) and 4*h (layer norms). Reasoning layers use
    ``reasoning_intermediate_size`` as ``ffn``. The embedding and the output
    projection are counted separately (untied weights are assumed), plus one
    final norm of size h.
    """
    hidden = model_config.hidden_size
    vocab = model_config.vocab_size

    def layer_params(ffn_size):
        return 4 * hidden * hidden + 2 * hidden * ffn_size + 4 * hidden

    # A missing (None) reasoning-layer count is treated as zero instead of
    # letting ``None * int`` raise a TypeError.
    num_reasoning = 0
    if model_config.use_reasoning_layer:
        num_reasoning = model_config.num_reasoning_layers or 0

    return (
        vocab * hidden                                                    # embeddings
        + model_config.num_hidden_layers * layer_params(model_config.intermediate_size)
        + num_reasoning * layer_params(model_config.reasoning_intermediate_size)
        + hidden                                                          # final layer norm
        + hidden * vocab                                                  # output projection
    )


def _log_resource_estimates(config, logger):
    """Log the approximate parameter count and memory footprint.

    Logs a warning when the estimate exceeds 90% of the assumed total device
    memory (32 GB per device, the TPU v4 per-chip HBM figure).
    """
    param_count = _estimate_param_count(config.model_config)
    logger.info(f"Approximate parameter count: {param_count / 1e9:.2f} billion parameters")

    bytes_per_param = 2 if config.dtype == jnp.bfloat16 else 4  # bfloat16 or float32
    model_size_gb = param_count * bytes_per_param / 1e9
    optimizer_size_gb = model_size_gb * 2  # Adam keeps two moment tensors per parameter
    activation_size_gb = model_size_gb * 0.2  # Rough estimate for activations
    total_memory_gb = model_size_gb + optimizer_size_gb + activation_size_gb

    logger.info("Estimated memory requirements:")
    logger.info(f"  Model parameters: {model_size_gb:.2f} GB")
    logger.info(f"  Optimizer states: {optimizer_size_gb:.2f} GB")
    logger.info(f"  Activations: {activation_size_gb:.2f} GB")
    logger.info(f"  Total: {total_memory_gb:.2f} GB")

    tpu_memory_gb = 32 * jax.device_count()  # Each TPU v4 has 32GB HBM
    logger.info(f"Available TPU memory: {tpu_memory_gb:.2f} GB")
    if total_memory_gb > tpu_memory_gb * 0.9:  # Leave 10% margin
        logger.warning(f"Memory requirements ({total_memory_gb:.2f} GB) may exceed available TPU memory ({tpu_memory_gb:.2f} GB)")
        logger.warning("Consider enabling gradient checkpointing and using a smaller batch size")


def main():
    """Entry point: parse arguments, build all components, and run training.

    Steps:
    - print JAX device information and build the config;
    - set up file logging and, optionally, Weights & Biases;
    - log the run settings and a resource estimate;
    - create the parallelism strategy, model, optimizer and sharded training
      state, optionally restoring a checkpoint;
    - load the tokenizer, datasets and data loaders, then run ``Trainer.train``.

    The W&B run, if started, is finished whether or not training succeeds.
    """
    # Parse arguments
    args = parse_args()

    # Print TPU configuration information
    print("TPU Configuration:")
    print(f"Number of TPU devices: {jax.device_count()}")
    print(f"TPU devices: {jax.devices()}")
    print(f"JAX process index: {jax.process_index()}")
    print(f"JAX process count: {jax.process_count()}")
    print(f"JAX local devices: {jax.local_devices()}")
    print(f"JAX local device count: {jax.local_device_count()}")

    # Create configuration
    config = create_config(args)

    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)

    # Set up logging (this local logger shadows the module-level one)
    logger = setup_logger(
        name="tpu_train",
        log_file=os.path.join(config.output_dir, "train.log")
    )

    # Log configuration
    logger.info(f"Configuration: {config}")

    # Initialize Weights & Biases if enabled
    wandb_active = _init_wandb(args, logger)

    # Log hardware information
    logger.info(f"Training on TPU v4-32 with {jax.device_count()} devices")
    logger.info(f"Model size: {args.model_size} ({config.model_config.hidden_size} hidden size, "
                f"{config.model_config.num_hidden_layers} layers)")
    logger.info(f"Max sequence length: {args.max_seq_length} tokens")
    logger.info(f"Batch size: {args.batch_size} per device, {args.batch_size * jax.device_count()} global")
    logger.info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
    logger.info(f"Effective batch size: {args.batch_size * jax.device_count() * args.gradient_accumulation_steps}")
    logger.info(f"Learning rate: {args.learning_rate}")
    logger.info(f"Warmup steps: {args.warmup_steps}")
    logger.info(f"Max steps: {args.max_steps}")
    logger.info(f"Parallelism type: {args.parallelism_type}")
    logger.info(f"Tensor parallel size: {args.tensor_parallel_size}")
    logger.info(f"Using streaming dataset: {args.use_streaming}")
    logger.info(f"Using flash attention: {args.use_flash_attention}")
    logger.info(f"Using gradient checkpointing: {args.use_gradient_checkpointing}")
    logger.info(f"Using RoPE scaling: {args.use_rope_scaling}")
    logger.info(f"RoPE scaling factor: {args.rope_scaling_factor}")
    logger.info(f"Using reasoning layer: {args.use_reasoning_layer}")
    logger.info(f"Number of reasoning layers: {config.model_config.num_reasoning_layers}")
    logger.info(f"Random seed: {args.seed}")
    logger.info(f"Output directory: {args.output_dir}")
    logger.info(f"Logging steps: {args.logging_steps}")
    logger.info(f"Save steps: {args.save_steps}")
    logger.info(f"Eval steps: {args.eval_steps}")
    logger.info(f"Profile steps: {args.profile_steps}")
    logger.info(f"Using Weights & Biases: {args.use_wandb and WANDB_AVAILABLE}")
    logger.info(f"Logging memory usage: {args.log_memory_usage}")

    # Parameter-count and memory estimate (previously computed and logged twice)
    _log_resource_estimates(config, logger)

    # Set random seed
    rng = jax.random.PRNGKey(config.seed)

    # Measure initialization time
    start_time = time.time()

    # Set up parallelism with optimized configuration for TPU v4-32
    parallel = setup_parallelism(config)

    # Create model
    model = create_model(config)
    logger.info(f"Model created in {time.time() - start_time:.2f} seconds")

    # Create optimizer with memory-efficient configuration
    optimizer = create_optimizer(config, config.max_steps)

    # Create training state
    state_start_time = time.time()
    state = create_train_state(config, model, optimizer, rng)
    logger.info(f"Training state created in {time.time() - state_start_time:.2f} seconds")

    # Shard parameters with optimized sharding strategy
    shard_start_time = time.time()
    state = state.replace(params=parallel.shard_params(state.params))
    logger.info(f"Parameters sharded in {time.time() - shard_start_time:.2f} seconds")

    # Load checkpoint if requested
    if args.resume_from_checkpoint:
        checkpoint_start_time = time.time()
        state, resumed_step = load_checkpoint(args.resume_from_checkpoint, state)
        logger.info(f"Checkpoint loaded in {time.time() - checkpoint_start_time:.2f} seconds")
        logger.info(f"Resumed from step {resumed_step}")

    # Load tokenizer
    tokenizer_start_time = time.time()
    tokenizer = load_tokenizer(config)
    logger.info(f"Tokenizer loaded in {time.time() - tokenizer_start_time:.2f} seconds")

    # Load dataset with optimized loading
    dataset_start_time = time.time()
    train_dataset, eval_dataset = load_dataset(config, tokenizer)
    logger.info(f"Datasets loaded in {time.time() - dataset_start_time:.2f} seconds")

    # Create data loaders with optimized configuration for TPU v4-32
    dataloader_start_time = time.time()
    train_loader, eval_loader = create_data_loaders(
        config,
        train_dataset,
        eval_dataset,
        tokenizer
    )
    logger.info(f"Data loaders created in {time.time() - dataloader_start_time:.2f} seconds")

    # Create TensorBoard summary writer (creates the directory; note it is not
    # currently passed to the Trainer)
    summary_writer = create_summary_writer(
        os.path.join(config.output_dir, "tensorboard")
    )

    # Create trainer configuration with optimized settings for TPU v4-32
    trainer_config = TrainerConfig(
        model_config=config.model_config,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_steps=config.warmup_steps,
        max_steps=config.max_steps,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=config.max_grad_norm,
        adam_beta1=config.adam_beta1,
        adam_beta2=config.adam_beta2,
        adam_epsilon=config.adam_epsilon,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        eval_steps=config.eval_steps,
        output_dir=config.output_dir,
        seed=config.seed,
        dtype=config.dtype,
        # Additional optimized settings for TPU v4-32
        use_pjit=True,  # Use pjit for better performance
        use_scan=True,  # Use scan for layer iteration
        use_remat=config.model_config.use_gradient_checkpointing,  # Use rematerialization for memory efficiency
        use_sharded_optim=True,  # Use sharded optimizer states
        profile_steps=args.profile_steps,  # Honour --profile_steps (was hard-coded to 100)
        async_checkpointing=True,  # Use async checkpointing for better performance
    )

    # Create trainer
    trainer = Trainer(
        config=trainer_config,
        model=model,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        state=state,
        parallel=parallel,  # Pass parallelism object for optimized training
    )

    # Log total initialization time
    logger.info(f"Total initialization time: {time.time() - start_time:.2f} seconds")

    # Calculate estimated training time.
    # Rough throughput assumption: 100 steps every 5 minutes. The previous
    # formula omitted the factor of 100 and overstated the estimate 100x.
    steps_per_day = (24 * 60 * 60) / (5 * 60) * 100
    estimated_days = config.max_steps / steps_per_day
    logger.info(f"Estimated training time: {estimated_days:.2f} days for {config.max_steps} steps")

    # Train model with performance monitoring
    try:
        train_start_time = time.time()
        trainer.train()
        train_duration = time.time() - train_start_time
        logger.info(f"Training completed in {train_duration / 3600:.2f} hours")
        logger.info(f"Average training speed: {config.max_steps / train_duration:.2f} steps/second")
    except Exception as e:
        # logger.exception keeps the traceback in train.log before re-raising.
        logger.exception(f"Training failed with error: {e}")
        raise
    finally:
        # Close the W&B run so its final state is flushed, even on failure.
        if wandb_active:
            wandb.finish()


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()