milandean committed on
Commit
8490370
1 Parent(s): e4fc4f7

Upload example_fsdp.py with huggingface_hub

Files changed (1)
  1. example_fsdp.py +62 -0
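
The commit message above matches the default that huggingface_hub generates for its upload helpers. For context, a minimal upload sketch (the repo id is a placeholder, not taken from this page, and a prior "huggingface-cli login" is assumed):

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="example_fsdp.py",   # local file to push
        path_in_repo="example_fsdp.py",      # destination path inside the repo
        repo_id="your-username/your-repo",   # placeholder repo id
        commit_message="Upload example_fsdp.py with huggingface_hub",
    )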
example_fsdp.py ADDED
@@ -0,0 +1,62 @@
+ # Make sure to run the script with the following envs:
+ # PJRT_DEVICE=TPU XLA_USE_SPMD=1
+
+ import torch
+ import torch_xla
+
+ import torch_xla.core.xla_model as xm
+
+ from datasets import load_dataset
+ from peft import LoraConfig, get_peft_model
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+ from trl import SFTTrainer
+
+ # Set up TPU device.
+ device = xm.xla_device()
+ model_id = "google/gemma-7b"
+
+ # Load the pretrained model and tokenizer.
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+
+ # Set up PEFT LoRA for fine-tuning.
+ lora_config = LoraConfig(
+     r=8,
+     target_modules=["k_proj", "v_proj"],
+     task_type="CAUSAL_LM",
+ )
+
+ # Load the dataset and format it for training.
+ data = load_dataset("Abirate/english_quotes", split="train")
+ max_seq_length = 1024
+
+ # Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
+ fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": [
+         "GemmaDecoderLayer"
+     ],
+     "xla": True,
+     "xla_fsdp_v2": True,
+     "xla_fsdp_grad_ckpt": True}
+
+ # Finally, set up the trainer and train the model.
+ trainer = SFTTrainer(
+     model=model,
+     train_dataset=data,
+     args=TrainingArguments(
+         per_device_train_batch_size=64,  # This is actually the global batch size for SPMD.
+         num_train_epochs=100,
+         max_steps=-1,
+         output_dir="./output",
+         optim="adafactor",
+         logging_steps=1,
+         dataloader_drop_last=True,  # Required for SPMD.
+         fsdp="full_shard",
+         fsdp_config=fsdp_config,
+     ),
+     peft_config=lora_config,
+     dataset_text_field="quote",
+     max_seq_length=max_seq_length,
+     packing=True,
+ )
+
+ trainer.train()
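
Per the header comment, the script expects the PJRT_DEVICE and XLA_USE_SPMD environment variables to be set when it is launched. A minimal invocation on a TPU VM might look like this (assuming the file is saved as example_fsdp.py):

    PJRT_DEVICE=TPU XLA_USE_SPMD=1 python example_fsdp.py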