cleanup, prep for 4bit quant support
Browse files- README.md +21 -1
- scripts/finetune.py +18 -6
- setup.cfg +3 -0
README.md
CHANGED
@@ -30,4 +30,24 @@ shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
|
|
30 |
|
31 |
- Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
|
32 |
- Install python dependencies `pip3 install -r requirements.txt`
|
33 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
- Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
|
32 |
- Install python dependencies `pip3 install -r requirements.txt`
|
33 |
+
- Configure accelerate `accelerate launch` or update `~/.cache/huggingface/accelerate/default_config.yaml`
|
34 |
+
|
35 |
+
```yaml
|
36 |
+
compute_environment: LOCAL_MACHINE
|
37 |
+
distributed_type: MULTI_GPU
|
38 |
+
downcast_bf16: 'no'
|
39 |
+
gpu_ids: all
|
40 |
+
machine_rank: 0
|
41 |
+
main_training_function: main
|
42 |
+
mixed_precision: bf16
|
43 |
+
num_machines: 1
|
44 |
+
num_processes: 4
|
45 |
+
rdzv_backend: static
|
46 |
+
same_network: true
|
47 |
+
tpu_env: []
|
48 |
+
tpu_use_cluster: false
|
49 |
+
tpu_use_sudo: false
|
50 |
+
use_cpu: false
|
51 |
+
```
|
52 |
+
|
53 |
+
- Train! `accelerate launch scripts/finetune.py`, make sure to choose the correct YAML config file
|
scripts/finetune.py
CHANGED
@@ -68,26 +68,27 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
|
|
68 |
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
69 |
replace_llama_attn_with_flash_attn()
|
70 |
|
|
|
71 |
try:
|
72 |
if "llama" in base_model:
|
73 |
model = LlamaForCausalLM.from_pretrained(
|
74 |
base_model,
|
75 |
load_in_8bit=cfg.load_in_8bit,
|
76 |
-
torch_dtype=
|
77 |
device_map=cfg.device_map,
|
78 |
)
|
79 |
else:
|
80 |
model = getattr(transformers, model_type).from_pretrained(
|
81 |
base_model,
|
82 |
load_in_8bit=cfg.load_in_8bit,
|
83 |
-
torch_dtype=
|
84 |
device_map=cfg.device_map,
|
85 |
)
|
86 |
except:
|
87 |
model = AutoModelForCausalLM.from_pretrained(
|
88 |
base_model,
|
89 |
load_in_8bit=cfg.load_in_8bit,
|
90 |
-
torch_dtype=
|
91 |
device_map=cfg.device_map,
|
92 |
)
|
93 |
|
@@ -235,7 +236,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
235 |
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
|
236 |
|
237 |
training_arguments_kwargs = {}
|
238 |
-
|
|
|
|
|
|
|
239 |
training_arguments_kwargs["tf32"] = cfg.tf32
|
240 |
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
241 |
training_arguments_kwargs["logging_steps"] = logging_steps
|
@@ -256,10 +260,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
256 |
group_by_length=cfg.group_by_length,
|
257 |
report_to="wandb" if cfg.use_wandb else None,
|
258 |
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
|
|
259 |
**training_arguments_kwargs,
|
260 |
)
|
261 |
|
262 |
-
trainer_kwargs = {}
|
263 |
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
264 |
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
265 |
optimizer_grouped_parameters = [
|
@@ -282,13 +286,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
282 |
lr=training_args.learning_rate,
|
283 |
)
|
284 |
|
|
|
285 |
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
|
286 |
adam_bnb_optim,
|
287 |
training_args.warmup_steps,
|
288 |
total_num_steps,
|
289 |
)
|
290 |
-
trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
|
291 |
|
|
|
292 |
if cfg.early_stopping_patience:
|
293 |
early_stop_cb = EarlyStoppingCallback(
|
294 |
cfg.early_stopping_patience,
|
@@ -300,6 +305,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
300 |
train_dataset=train_dataset,
|
301 |
eval_dataset=eval_dataset,
|
302 |
args=training_args,
|
|
|
303 |
data_collator=transformers.DataCollatorForSeq2Seq(
|
304 |
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
305 |
),
|
@@ -342,6 +348,12 @@ def train(
|
|
342 |
cfg.gradient_accumulation_steps // cfg.world_size
|
343 |
)
|
344 |
setup_wandb_env_vars(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
|
346 |
# Load the model and tokenizer
|
347 |
model, tokenizer, lora_config = load_model(
|
|
|
68 |
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
69 |
replace_llama_attn_with_flash_attn()
|
70 |
|
71 |
+
torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
|
72 |
try:
|
73 |
if "llama" in base_model:
|
74 |
model = LlamaForCausalLM.from_pretrained(
|
75 |
base_model,
|
76 |
load_in_8bit=cfg.load_in_8bit,
|
77 |
+
torch_dtype=torch_dtype,
|
78 |
device_map=cfg.device_map,
|
79 |
)
|
80 |
else:
|
81 |
model = getattr(transformers, model_type).from_pretrained(
|
82 |
base_model,
|
83 |
load_in_8bit=cfg.load_in_8bit,
|
84 |
+
torch_dtype=torch_dtype,
|
85 |
device_map=cfg.device_map,
|
86 |
)
|
87 |
except:
|
88 |
model = AutoModelForCausalLM.from_pretrained(
|
89 |
base_model,
|
90 |
load_in_8bit=cfg.load_in_8bit,
|
91 |
+
torch_dtype=torch_dtype,
|
92 |
device_map=cfg.device_map,
|
93 |
)
|
94 |
|
|
|
236 |
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
|
237 |
|
238 |
training_arguments_kwargs = {}
|
239 |
+
if cfg.bf16 == "full":
|
240 |
+
training_arguments_kwargs["bf16_full_eval"] = True
|
241 |
+
else:
|
242 |
+
training_arguments_kwargs["bf16"] = cfg.bf16
|
243 |
training_arguments_kwargs["tf32"] = cfg.tf32
|
244 |
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
245 |
training_arguments_kwargs["logging_steps"] = logging_steps
|
|
|
260 |
group_by_length=cfg.group_by_length,
|
261 |
report_to="wandb" if cfg.use_wandb else None,
|
262 |
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
263 |
+
gradient_checkpointing=cfg.gradient_checkpointing,
|
264 |
**training_arguments_kwargs,
|
265 |
)
|
266 |
|
|
|
267 |
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
268 |
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
269 |
optimizer_grouped_parameters = [
|
|
|
286 |
lr=training_args.learning_rate,
|
287 |
)
|
288 |
|
289 |
+
# TODO optionally use torch.optim.OneCycleLR
|
290 |
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
|
291 |
adam_bnb_optim,
|
292 |
training_args.warmup_steps,
|
293 |
total_num_steps,
|
294 |
)
|
|
|
295 |
|
296 |
+
trainer_kwargs = {}
|
297 |
if cfg.early_stopping_patience:
|
298 |
early_stop_cb = EarlyStoppingCallback(
|
299 |
cfg.early_stopping_patience,
|
|
|
305 |
train_dataset=train_dataset,
|
306 |
eval_dataset=eval_dataset,
|
307 |
args=training_args,
|
308 |
+
optimizers=(adam_bnb_optim, lr_scheduler),
|
309 |
data_collator=transformers.DataCollatorForSeq2Seq(
|
310 |
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
311 |
),
|
|
|
348 |
cfg.gradient_accumulation_steps // cfg.world_size
|
349 |
)
|
350 |
setup_wandb_env_vars(cfg)
|
351 |
+
if cfg.device == "mps":
|
352 |
+
cfg.load_in_8bit = False
|
353 |
+
cfg.tf32 = False
|
354 |
+
if cfg.bf16:
|
355 |
+
cfg.fp16 = True
|
356 |
+
cfg.bf16 = False
|
357 |
|
358 |
# Load the model and tokenizer
|
359 |
model, tokenizer, lora_config = load_model(
|
setup.cfg
CHANGED
@@ -28,3 +28,6 @@ install_requires =
|
|
28 |
[options.packages.find]
|
29 |
where = src
|
30 |
|
|
|
|
|
|
|
|
28 |
[options.packages.find]
|
29 |
where = src
|
30 |
|
31 |
+
[options.extras_require]
|
32 |
+
gptq_cuda = alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]
|
33 |
+
gptq_triton = alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]
|