winglian committed
Commit 4b997c3
Parent: fac2d98

allow the optimizer prune ratio for ReLoRA to be configurable (#1287)


* allow the optimizer prune ratio for relora to be configurable

* update docs for relora

* prevent circular imports

README.md CHANGED
@@ -734,6 +734,8 @@ peft:
 # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
 relora_steps: # Number of steps per ReLoRA restart
 relora_warmup_steps: # Number of per-restart warmup steps
+relora_anneal_steps: # Number of anneal steps for each relora cycle
+relora_prune_ratio: # threshold for optimizer magnitude when pruning
 relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
 
 # wandb configuration if you're using it
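For illustration, a minimal sketch of how the new options might sit alongside the existing ReLoRA settings in a config, assuming the usual `adapter: lora` setup; the step counts and ratio are placeholder values, not recommendations:

```yaml
# Hypothetical config excerpt using the ReLoRA options documented above (values are illustrative)
adapter: lora             # ReLoRA requires the 'lora' or 'qlora' adapter
relora_steps: 200         # restart the LoRA weights every 200 steps
relora_warmup_steps: 10   # per-restart warmup steps
relora_anneal_steps: 10   # anneal steps for each relora cycle
relora_prune_ratio: 0.9   # threshold for optimizer magnitude when pruning
relora_cpu_offload: false
```

Leaving `relora_prune_ratio` unset keeps the previous behaviour, since the new training-argument field defaults to the formerly hard-coded 0.9.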
src/axolotl/core/trainer_builder.py CHANGED
@@ -131,6 +131,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
     )
+    relora_prune_ratio: Optional[float] = field(
+        default=0.9,
+        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -900,9 +904,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "sample_packing_seq_len_multiplier"
             ] = self.cfg.micro_batch_size
-        training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
-        training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
-        training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
+        if self.cfg.relora_steps:
+            training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
+            training_arguments_kwargs[
+                "relora_warmup_steps"
+            ] = self.cfg.relora_warmup_steps
+            if self.cfg.relora_anneal_steps:
+                training_arguments_kwargs[
+                    "relora_anneal_steps"
+                ] = self.cfg.relora_anneal_steps
+            if self.cfg.relora_prune_ratio:
+                training_arguments_kwargs[
+                    "relora_prune_ratio"
+                ] = self.cfg.relora_prune_ratio
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
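The new block is a plain conditional-forwarding pattern: ReLoRA-related keys are only copied into the training-arguments kwargs when `relora_steps` is set, so runs without ReLoRA never receive them. A rough standalone sketch of that pattern (not axolotl's exact code; the `cfg` dict below is hypothetical):

```python
# Illustrative sketch of the conditional forwarding introduced in the diff above.
# `cfg` is a hypothetical plain dict standing in for axolotl's config object.
def collect_relora_kwargs(cfg: dict) -> dict:
    kwargs = {}
    if cfg.get("relora_steps"):
        # relora_steps enables ReLoRA, so warmup steps are forwarded along with it
        kwargs["relora_steps"] = cfg["relora_steps"]
        kwargs["relora_warmup_steps"] = cfg.get("relora_warmup_steps")
        # optional settings are only forwarded when explicitly configured
        if cfg.get("relora_anneal_steps"):
            kwargs["relora_anneal_steps"] = cfg["relora_anneal_steps"]
        if cfg.get("relora_prune_ratio"):
            kwargs["relora_prune_ratio"] = cfg["relora_prune_ratio"]
    return kwargs


print(collect_relora_kwargs({"relora_steps": 200, "relora_prune_ratio": 0.8}))
# {'relora_steps': 200, 'relora_warmup_steps': None, 'relora_prune_ratio': 0.8}
```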
src/axolotl/monkeypatch/relora.py CHANGED
@@ -46,8 +46,9 @@ def reset_optimizer(
     *,
     reset_params: list[str],  # where str is the key to a torch.nn.Parameter
     optimizer_state_keys: list[str],
+    prune_ratio: float = 0.9,
 ):
-    pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
+    pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
     n_zeros = 0
     n_total = 0
 
@@ -159,6 +160,7 @@ class ReLoRACallback(TrainerCallback):
                 optimizer,
                 reset_params=lora_params,
                 optimizer_state_keys=optimizer_state_keys,
+                prune_ratio=args.relora_prune_ratio,
             )
 
             if self.quantized:
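This diff only threads `prune_ratio` through to `magnitude_pruning_`, whose body is not shown here. As a rough mental model only (an assumption about the semantics, not axolotl's implementation), magnitude pruning of an optimizer-state tensor with a prune ratio of 0.9 can be pictured as zeroing the 90% of entries with the smallest absolute value:

```python
# Hypothetical sketch of in-place magnitude pruning on an optimizer-state tensor;
# axolotl's actual magnitude_pruning_ may differ in detail.
import torch


def magnitude_pruning_sketch(tensor: torch.Tensor, prune_ratio: float) -> None:
    """Zero out the `prune_ratio` fraction of entries with the smallest magnitude."""
    if prune_ratio <= 0:
        return
    k = int(tensor.numel() * prune_ratio)
    if k == 0:
        return
    # threshold = magnitude of the k-th smallest entry (by absolute value)
    threshold = tensor.abs().flatten().kthvalue(k).values
    tensor.masked_fill_(tensor.abs() <= threshold, 0.0)


state = torch.randn(1000)
magnitude_pruning_sketch(state, prune_ratio=0.9)
print((state == 0).float().mean())  # roughly 0.9 of the entries are now zero
```

Under that reading, a higher `relora_prune_ratio` discards more of the optimizer state at each ReLoRA restart, and the old hard-coded 0.9 is preserved as the default.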