Update LoRA fine-tune example - more target_modules, lower LR, bf16

#49
by michael-go - opened
Files changed (1) hide show
  1. README.md +21 -14
README.md CHANGED
@@ -96,31 +96,40 @@ model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
96
  </details>
97
 
98
  ### Fine-tuning example
99
- Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library:
100
 
101
  ```python
 
102
  from datasets import load_dataset
103
- from trl import SFTTrainer
104
  from peft import LoraConfig
105
  from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
106
 
107
  tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
108
- model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map='auto')
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  dataset = load_dataset("Abirate/english_quotes", split="train")
111
- training_args = TrainingArguments(
112
  output_dir="./results",
113
- num_train_epochs=3,
114
  per_device_train_batch_size=4,
115
  logging_dir='./logs',
116
  logging_steps=10,
117
- learning_rate=2e-3
118
- )
119
- lora_config = LoraConfig(
120
- r=8,
121
- target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
122
- task_type="CAUSAL_LM",
123
- bias="none"
124
  )
125
  trainer = SFTTrainer(
126
  model=model,
@@ -128,9 +137,7 @@ trainer = SFTTrainer(
128
  args=training_args,
129
  peft_config=lora_config,
130
  train_dataset=dataset,
131
- dataset_text_field="quote",
132
  )
133
-
134
  trainer.train()
135
  ```
136
 
 
96
  </details>
97
 
98
  ### Fine-tuning example
99
+ Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library (requires ~120GB GPU RAM, in example 2xA100 80GB):
100
 
101
  ```python
102
+ import torch
103
  from datasets import load_dataset
104
+ from trl import SFTTrainer, SFTConfig
105
  from peft import LoraConfig
106
  from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
107
 
108
  tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
109
+ model = AutoModelForCausalLM.from_pretrained(
110
+ "ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)
111
+
112
+ lora_config = LoraConfig(
113
+ r=8,
114
+ target_modules=[
115
+ "embed_tokens",
116
+ "x_proj", "in_proj", "out_proj", # mamba
117
+ "gate_proj", "up_proj", "down_proj", # mlp
118
+ "q_proj", "k_proj", "v_proj" # attention
119
+ ],
120
+ task_type="CAUSAL_LM",
121
+ bias="none"
122
+ )
123
 
124
  dataset = load_dataset("Abirate/english_quotes", split="train")
125
+ training_args = SFTConfig(
126
  output_dir="./results",
127
+ num_train_epochs=2,
128
  per_device_train_batch_size=4,
129
  logging_dir='./logs',
130
  logging_steps=10,
131
+ learning_rate=1e-5,
132
+ dataset_text_field="quote",
 
 
 
 
 
133
  )
134
  trainer = SFTTrainer(
135
  model=model,
 
137
  args=training_args,
138
  peft_config=lora_config,
139
  train_dataset=dataset,
 
140
  )
 
141
  trainer.train()
142
  ```
143