# minilm/minilm-1.py
import os
# Set environment variables before imports
os.environ['FI_PROVIDER'] = 'tcp'
os.environ['CCL_ATL_TRANSPORT'] = 'ofi'
from transformers import DataCollatorForLanguageModeling, LlamaForCausalLM, AutoTokenizer, AutoConfig
from datasets import load_dataset
from torch.optim import AdamW # Import AdamW from PyTorch
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
# Set default values for RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT if not provided
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
# Define model_name and tokenizer before 'train_model'
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
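
# Optional sanity check (a sketch): Llama tokenizers typically ship without a dedicated
# pad token, so the two assignments above alias padding to the EOS token.
assert tokenizer.pad_token_id == tokenizer.eos_token_id
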
def setup(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend='ccl')

def cleanup():
    dist.destroy_process_group()
def train_model(rank, world_size):
    setup(rank, world_size)
    device = torch.device(f'xpu:{rank}')
    torch.xpu.set_device(device)
    # Build a reduced model config, then initialize the model and move it to the device
    config = AutoConfig.from_pretrained(model_name)

    # Core dimensions
    config.hidden_size = 768
    config.intermediate_size = 3072

    # Reduce model depth
    config.num_hidden_layers = 16
    config.num_attention_heads = 12  # must divide hidden_size (768 / 12 = 64)
    config.num_key_value_heads = 6   # grouped-query attention: 2 query heads per KV head
    config.head_dim = config.hidden_size // config.num_attention_heads

    # Reduce memory footprint
    config.max_position_embeddings = 8192
    # config.vocab_size = 32000

    # Other hyperparameters
    config.rope_theta = 10000.0
    config.initializer_range = 0.02
    config.rms_norm_eps = 1e-6
    config.attention_dropout = 0.1

    # Token IDs
    config.bos_token_id = 31998
    config.eos_token_id = 31999
    config.pad_token_id = 0

    # Simplified rope scaling
    config.rope_scaling = {
        "type": "linear",
        "factor": 1.0
    }

    # Keep essential settings
    config.tie_word_embeddings = True
    config.torch_dtype = "bfloat16"
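
    # Illustrative helper (a sketch, not part of the training flow): rough parameter
    # count derived from the config above, assuming the standard Llama block layout
    # (bias-free attention and gated-MLP projections plus RMSNorm weights) and tied
    # input/output embeddings. The exact count is printed from the instantiated model below.
    def estimate_llama_params(cfg):
        d, n_layers, ff, vocab = cfg.hidden_size, cfg.num_hidden_layers, cfg.intermediate_size, cfg.vocab_size
        kv_dim = cfg.head_dim * cfg.num_key_value_heads
        attn = 2 * d * d + 2 * d * kv_dim        # q/o projections (d x d) and k/v projections (d x kv_dim)
        mlp = 3 * d * ff                         # gate, up, and down projections
        per_layer = attn + mlp + 2 * d           # plus two RMSNorm weight vectors per layer
        return vocab * d + n_layers * per_layer + d  # embeddings + decoder layers + final norm

    if rank == 0:
        print(f"Estimated parameters from config: {estimate_llama_params(config) / 1e6:.1f}M")
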
    model = LlamaForCausalLM(config).to(device, dtype=torch.bfloat16)

    # Wrap model with DDP after moving to device
    model = DDP(model, device_ids=[rank])
    # Load and tokenize dataset inside 'train_model'
    dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", cache_dir="./").shuffle(seed=42)
    if rank == 0:
        # Print all column names
        print(dataset.column_names)

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=1024)

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['id', 'url', 'title', 'text'],
        num_proc=24  # number of parallel tokenization processes
    )
    # Data collator for language modeling (causal, so mlm=False)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        return_tensors="pt"
    )
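
    # Quick shape check (optional sketch): collate two tokenized rows on rank 0 to
    # confirm the fixed-length padding produced by the tokenizer settings above.
    if rank == 0:
        sample = data_collator([tokenized_datasets[i] for i in range(2)])
        print({k: tuple(v.shape) for k, v in sample.items()})
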
    # Prepare DataLoader with distributed sampler
    train_sampler = DistributedSampler(
        tokenized_datasets,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=42
    )
    train_dataloader = DataLoader(
        tokenized_datasets,
        batch_size=8,
        sampler=train_sampler,
        collate_fn=data_collator,
        num_workers=2,
        # prefetch_factor=8,
        persistent_workers=True,
        pin_memory=True,
    )
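
    # Scale check (optional sketch): with a per-rank batch size of 8, each of the
    # world_size ranks sees roughly len(dataset) / (8 * world_size) batches per epoch.
    if rank == 0:
        print(f"~{len(train_dataloader)} optimizer steps per rank per epoch")
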
    if rank == 0:
        # Check the number of parameters in the model
        num_params = model.module.num_parameters() / 1e6
        print(f"The model has {num_params:.2f}M parameters.")
    # Initialize optimizer
    optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
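
    # Note: intel_extension_for_pytorch is imported above but never applied. A common
    # optional pattern (a sketch, not enabled here; verify against your IPEX version) is
    # to let IPEX optimize the raw model and optimizer before the DDP wrap, e.g.:
    #   model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
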
    # Training loop
    model.train()
    num_epochs = 40
    step = 0
    running_loss = 0
    logging_steps = 500  # Adjust as needed

    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)
        epoch_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}",
                              position=0, leave=True, disable=rank != 0)
        epoch_loss = 0
        num_batches = 0
        for batch in epoch_iterator:
            optimizer.zero_grad()

            # Move batch to device
            input_ids = batch['input_ids'].to(device, dtype=torch.long)
            attention_mask = batch['attention_mask'].to(device, dtype=torch.bfloat16)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss

            # Backward pass
            loss.backward()

            # Gradient clipping (optional but recommended)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Aggregate loss across processes
            if world_size > 1:
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss = loss / world_size
            # Update metrics on rank 0
            if rank == 0:
                running_loss = (0.98 * running_loss + 0.02 * loss.item()) if step > 0 else loss.item()
                epoch_loss += loss.item()
                num_batches += 1

                # Update progress bar
                epoch_iterator.set_postfix({
                    'loss': f'{running_loss:.4f}',
                    'batch_loss': f'{loss.item():.4f}'
                })

                # Detailed logging every N steps
                if step % logging_steps == 0:
                    print(f"\nStep {step}")
                    print(f"Running loss: {running_loss:.4f}")
                    print(f"Batch loss: {loss.item():.4f}")
                    print(f"Average epoch loss: {epoch_loss/num_batches:.4f}")

            step += 1
        # End-of-epoch reporting and checkpointing
        if rank == 0:
            avg_epoch_loss = epoch_loss / num_batches
            print(f"\nEpoch {epoch+1} completed.")
            print(f"Average epoch loss: {avg_epoch_loss:.4f}")

            # Save checkpoint
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_epoch_loss,
            }
            model.module.save_pretrained(f"fine-tuned-llama-8192-{epoch+1}-{step}")
            tokenizer.save_pretrained(f"fine-tuned-llama-8192-{epoch+1}-{step}")
            torch.save(checkpoint, f"fine-tuned-llama-8192-{epoch+1}-{step}/model.pt")

    cleanup()
if __name__ == "__main__":
    world_size = 16  # Number of XPUs
    mp.spawn(
        train_model,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )
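
# Reloading a saved epoch checkpoint (an illustrative sketch; the directory name below is
# hypothetical and follows the f"fine-tuned-llama-8192-{epoch+1}-{step}" pattern used above):
#   ckpt_dir = "fine-tuned-llama-8192-<epoch>-<step>"
#   model = LlamaForCausalLM.from_pretrained(ckpt_dir)
#   state = torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
#   optimizer.load_state_dict(state['optimizer_state_dict'])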