import os

# Set environment variables before the oneCCL/IPEX imports
os.environ['FI_PROVIDER'] = 'tcp'
os.environ['CCL_ATL_TRANSPORT'] = 'ofi'

from transformers import DataCollatorForLanguageModeling, LlamaForCausalLM, AutoTokenizer, AutoConfig
from datasets import load_dataset
from torch.optim import AdamW  # AdamW from PyTorch
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import intel_extension_for_pytorch as ipex  # registers the XPU backend
import oneccl_bindings_for_pytorch  # registers the 'ccl' backend for torch.distributed
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

# Set default values for RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT if not provided
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')

# Define model_name and tokenizer before 'train_model'
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token


def setup(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend='ccl')


def cleanup():
    dist.destroy_process_group()


def train_model(rank, world_size):
    setup(rank, world_size)
    device = torch.device(f'xpu:{rank}')
    torch.xpu.set_device(device)

    # Initialize model and move to device
    config = AutoConfig.from_pretrained(model_name)

    # Core dimensions (kept as required)
    config.hidden_size = 768
    config.intermediate_size = 3072

    # Reduce model complexity
    config.num_hidden_layers = 16  # reduced from 40
    config.num_attention_heads = 12  # must divide hidden_size (768 / 12 = 64)
    config.num_key_value_heads = 6  # grouped-query attention: 2 query heads per KV head
    config.head_dim = config.hidden_size // config.num_attention_heads

    # Reduce memory footprint
    config.max_position_embeddings = 8192
    # config.vocab_size = 32000

    # Optimize other parameters
    config.rope_theta = 10000.0
    config.initializer_range = 0.02
    config.rms_norm_eps = 1e-6
    config.attention_dropout = 0.1

    # Token IDs
    config.bos_token_id = 31998
    config.eos_token_id = 31999
    config.pad_token_id = 0

    # Simplified rope scaling
    config.rope_scaling = {
        "type": "linear",
        "factor": 1.0,
    }

    # Keep essential settings
    config.tie_word_embeddings = True
    config.torch_dtype = "bfloat16"

    model = LlamaForCausalLM(config).to(device, dtype=torch.bfloat16)

    # Wrap model with DDP after moving it to the device
    model = DDP(model, device_ids=[rank])

    # Load and tokenize the dataset inside 'train_model'
    dataset = load_dataset(
        "wikimedia/wikipedia", "20231101.en", split="train", cache_dir="./"
    ).shuffle(seed=42)

    if rank == 0:
        # Print all column names
        print(dataset.column_names)

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=1024,
        )

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['id', 'url', 'title', 'text'],
        num_proc=24,
    )

    # Data collator for causal language modeling
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        return_tensors="pt",
    )

    # Prepare DataLoader with a distributed sampler
    train_sampler = DistributedSampler(
        tokenized_datasets,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=42,
    )
    train_dataloader = DataLoader(
        tokenized_datasets,
        batch_size=8,
        sampler=train_sampler,
        collate_fn=data_collator,
        num_workers=2,
        # prefetch_factor=8,
        persistent_workers=True,
        pin_memory=True,
    )

    if rank == 0:
        # Report the number of parameters in the model
        num_params = model.module.num_parameters() / 1e6
        print(f"The model has {num_params:.2f}M parameters.")

    # Initialize optimizer
    optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

    # Training loop
    model.train()
    num_epochs = 40
    step = 0
    running_loss = 0
    logging_steps = 500  # Adjust as needed

    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)
        epoch_iterator = tqdm(
            train_dataloader,
            desc=f"Epoch {epoch+1}/{num_epochs}",
            position=0,
            leave=True,
            disable=rank != 0,
        )
        epoch_loss = 0
        num_batches = 0

        for batch in epoch_iterator:
            optimizer.zero_grad()

            # Move batch to device (keep the integer attention mask produced by the tokenizer)
            input_ids = batch['input_ids'].to(device, dtype=torch.long)
            attention_mask = batch['attention_mask'].to(device)

            # Forward pass; labels reuse input_ids, so the collator's padding-masked
            # 'labels' field is not used here
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss

            # Backward pass
            loss.backward()

            # Gradient clipping (optional but recommended)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            # Aggregate loss across processes for logging
            loss = loss.detach()
            if world_size > 1:
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss = loss / world_size

            # Update metrics
            if rank == 0:
                running_loss = (0.98 * running_loss + 0.02 * loss.item()) if step > 0 else loss.item()
                epoch_loss += loss.item()
                num_batches += 1

                # Update progress bar
                epoch_iterator.set_postfix({
                    'loss': f'{running_loss:.4f}',
                    'batch_loss': f'{loss.item():.4f}',
                })

                # Detailed logging every N steps
                if step % logging_steps == 0:
                    print(f"\nStep {step}")
                    print(f"Running loss: {running_loss:.4f}")
                    print(f"Batch loss: {loss.item():.4f}")
                    print(f"Average epoch loss: {epoch_loss / num_batches:.4f}")

            step += 1

        # End-of-epoch reporting
        if rank == 0:
            avg_epoch_loss = epoch_loss / num_batches
            print(f"\nEpoch {epoch+1} completed.")
            print(f"Average epoch loss: {avg_epoch_loss:.4f}")

        # Save checkpoint
        if rank == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_epoch_loss,
            }
            model.module.save_pretrained(f"fine-tuned-llama-8192-{epoch+1}-{step}")
            tokenizer.save_pretrained(f"fine-tuned-llama-8192-{epoch+1}-{step}")
            torch.save(checkpoint, f"fine-tuned-llama-8192-{epoch+1}-{step}/model.pt")

    cleanup()


if __name__ == "__main__":
    world_size = 16  # Number of XPUs
    mp.spawn(
        train_model,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )