import os

# Select the libfabric TCP provider and the OFI transport for oneCCL; these are
# read when the distributed backend is initialized.
os.environ['FI_PROVIDER'] = 'tcp'
os.environ['CCL_ATL_TRANSPORT'] = 'ofi'

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader, DistributedSampler

import intel_extension_for_pytorch as ipex  # enables XPU support and optimizations
import oneccl_bindings_for_pytorch  # registers the 'ccl' distributed backend

from datasets import load_dataset
from transformers import AutoConfig, AutoTokenizer, DataCollatorForLanguageModeling, LlamaForCausalLM
from tqdm import tqdm
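
# Default rendezvous settings; setup() below overrides these for each spawned worker.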
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')

model_name = "meta-llama/Llama-3.2-1B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
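

# Distributed setup/teardown: each worker exports its rank and world size, then
# joins a process group on the oneCCL ('ccl') backend used for XPU collectives.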
def setup(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend='ccl')


def cleanup():
    dist.destroy_process_group()
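

# Per-worker training routine: bind the process to its XPU, build a scaled-down
# Llama model from a modified config, and train it under DDP.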
def train_model(rank, world_size):
    setup(rank, world_size)

    device = torch.device(f'xpu:{rank}')
    torch.xpu.set_device(device)
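
    # Start from the Llama-3.2-1B config and shrink its width and depth; the model
    # built below is trained from randomly initialized weights.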
    config = AutoConfig.from_pretrained(model_name)

    config.hidden_size = 768
    config.intermediate_size = 3072
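
    # 768 hidden dims across 12 heads gives 64-dim heads; 6 KV heads means
    # grouped-query attention with 2 query heads per KV head.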
    config.num_hidden_layers = 16
    config.num_attention_heads = 12
    config.num_key_value_heads = 6
    config.head_dim = config.hidden_size // config.num_attention_heads

    config.max_position_embeddings = 8192

    config.rope_theta = 10000.0
    config.initializer_range = 0.02
    config.rms_norm_eps = 1e-6
    config.attention_dropout = 0.1

    # Keep the special-token ids in sync with the Llama-3.2 tokenizer used for the data.
    config.bos_token_id = tokenizer.bos_token_id
    config.eos_token_id = tokenizer.eos_token_id
    config.pad_token_id = tokenizer.pad_token_id

    config.rope_scaling = {
        "type": "linear",
        "factor": 1.0
    }

    config.tie_word_embeddings = True
    config.torch_dtype = "bfloat16"
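
    # Randomly initialized weights (no pretrained checkpoint is loaded), moved to
    # this worker's XPU and cast to bfloat16.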
    model = LlamaForCausalLM(config).to(device, dtype=torch.bfloat16)
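
    # DDP averages gradients across workers over the 'ccl' backend during
    # backward(); device_ids ties this replica to the local XPU.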
    model = DDP(model, device_ids=[rank])
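
    # Every rank loads and shuffles the full English Wikipedia dump; the shared
    # cache_dir lets ranks reuse already-downloaded Arrow files rather than
    # keeping private copies.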
    dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", cache_dir="./").shuffle(seed=42)

    if rank == 0:
        print(dataset.column_names)
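
    # Fixed-length tokenization: each article is truncated or padded (with the EOS
    # token, set as pad above) to 1024 tokens.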
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=1024)

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['id', 'url', 'title', 'text'],
        num_proc=24
    )
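
    # Causal-LM collator (mlm=False) batches the tokenized examples into tensors;
    # the training loop below passes labels=input_ids explicitly, so every
    # position (including padding) contributes to the loss.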
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        return_tensors="pt"
    )
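
    # DistributedSampler gives each rank a disjoint shard of the dataset;
    # set_epoch() in the training loop reshuffles the shards every epoch.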
    train_sampler = DistributedSampler(
        tokenized_datasets,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=42
    )

    train_dataloader = DataLoader(
        tokenized_datasets,
        batch_size=8,
        sampler=train_sampler,
        collate_fn=data_collator,
        num_workers=2,
        persistent_workers=True,
        pin_memory=True,
    )
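
    # Per-rank batch size of 8 sequences (1024 tokens each), i.e. 8 * world_size
    # sequences per global optimizer step.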

    if rank == 0:
        num_params = model.module.num_parameters() / 1e6
        print(f"The model has {num_params:.2f}M parameters.")
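
    # Plain AdamW with a fixed learning rate; this script uses no warmup or LR schedule.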
    optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

    model.train()
    num_epochs = 40
    step = 0
    running_loss = 0
    logging_steps = 500
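
    # Main loop: rank 0 maintains an exponential moving average of the loss
    # (running_loss) purely for progress reporting.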
    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)
        epoch_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}",
                              position=0, leave=True, disable=rank != 0)

        epoch_loss = 0
        num_batches = 0

        for batch in epoch_iterator:
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device, dtype=torch.long)
            attention_mask = batch['attention_mask'].to(device, dtype=torch.bfloat16)
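
            # Causal-LM objective: the labels are the input ids themselves;
            # LlamaForCausalLM shifts them internally, so no manual shift is needed.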
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
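
            # All-reduce the batch loss so rank 0 reports a global average; this is
            # for logging only, since DDP already synchronized the gradients during backward().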
            if world_size > 1:
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss = loss / world_size

            if rank == 0:
                running_loss = (0.98 * running_loss + 0.02 * loss.item()) if step > 0 else loss.item()
                epoch_loss += loss.item()
                num_batches += 1

                epoch_iterator.set_postfix({
                    'loss': f'{running_loss:.4f}',
                    'batch_loss': f'{loss.item():.4f}'
                })

                if step % logging_steps == 0:
                    print(f"\nStep {step}")
                    print(f"Running loss: {running_loss:.4f}")
                    print(f"Batch loss: {loss.item():.4f}")
                    print(f"Average epoch loss: {epoch_loss/num_batches:.4f}")

            step += 1

        if rank == 0:
            avg_epoch_loss = epoch_loss / num_batches
            print(f"\nEpoch {epoch+1} completed.")
            print(f"Average epoch loss: {avg_epoch_loss:.4f}")
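
        # Per-epoch checkpoint: a Hugging Face-format model directory plus a
        # torch.save()'d training-state dict. A resume would look roughly like
        # (sketch; the directory name is assumed):
        #     ckpt = torch.load("fine-tuned-llama-8192-1-<step>/model.pt")
        #     model.module.load_state_dict(ckpt['model_state_dict'])
        #     optimizer.load_state_dict(ckpt['optimizer_state_dict'])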
        if rank == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_epoch_loss,
            }
            model.module.save_pretrained(f"fine-tuned-llama-8192-{epoch+1}-{step}")
            tokenizer.save_pretrained(f"fine-tuned-llama-8192-{epoch+1}-{step}")
            torch.save(checkpoint, f"fine-tuned-llama-8192-{epoch+1}-{step}/model.pt")

    cleanup()
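

# Launch one training process per device; world_size = 16 assumes 16 XPU devices
# (or tiles) are visible on this node.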
if __name__ == "__main__":
    world_size = 16
    mp.spawn(
        train_model,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )