# Make sure to run the script with the following envs:
#   PJRT_DEVICE=TPU XLA_USE_SPMD=1

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

# Set up TPU device.
device = xm.xla_device()
model_id = "google/gemma-7b"

# Load the pretrained model and tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Set up PEFT LoRA for fine-tuning.
lora_config = LoraConfig(
    r=8,
    target_modules=["k_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Load the dataset and format it for training.
data = load_dataset("Abirate/english_quotes", split="train")
max_seq_length = 1024

# Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
fsdp_config = {
    "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True,
}

# Finally, set up the trainer and train the model.
trainer = SFTTrainer(
    model=model,
    train_dataset=data,
    args=TrainingArguments(
        per_device_train_batch_size=64,  # This is actually the global batch size for SPMD.
        num_train_epochs=100,
        max_steps=-1,
        output_dir="./output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last=True,  # Required for SPMD.
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    ),
    peft_config=lora_config,
    dataset_text_field="quote",
    max_seq_length=max_seq_length,
    packing=True,
)

trainer.train()
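
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script): assuming this file is saved
# as, e.g., "fsdp_lora_gemma.py" (hypothetical filename), it can be launched
# on a TPU VM with the env vars from the header comment, for example:
#
#   PJRT_DEVICE=TPU XLA_USE_SPMD=1 python fsdp_lora_gemma.py
#
# Optionally, the trained LoRA adapter can be persisted afterwards via the
# standard Trainer API; the output path below is an illustrative assumption:
#
#   trainer.save_model("./output/lora_adapter")
# ---------------------------------------------------------------------------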