Supporting gradient checkpointing for QLoRA

#16
by ospanbatyr - opened

Hi everyone,

While trying to fine-tune OLMo-7B with QLoRA, the error "OLMoForCausalLM does not support gradient checkpointing." is thrown at the prepare_model_for_kbit_training(model) line. Traceback:

Traceback (most recent call last):
  File "/scratch/users/oince22/hpc_run/CartographyFT/src/driver.py", line 39, in main
    run_main(P, logger)
  File "/scratch/users/oince22/hpc_run/CartographyFT/src/driver.py", line 68, in run_main
    llm, tokenizer = P.get_lm()
                     ^^^^^^^^^^
  File "/scratch/users/oince22/hpc_run/CartographyFT/src/params.py", line 311, in get_lm
    model = prepare_model_for_kbit_training(model)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kuacc/users/oince22/.conda/envs/icl/lib/python3.11/site-packages/peft/utils/other.py", line 139, in prepare_model_for_kbit_training
    model.gradient_checkpointing_enable(**gc_enable_kwargs)
  File "/kuacc/users/oince22/.conda/envs/icl/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2092, in gradient_checkpointing_enable
    raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
ValueError: OLMoForCausalLM does not support gradient checkpointing.
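For context, the last frame shows where this comes from: transformers' `gradient_checkpointing_enable` raises this ValueError when the model class does not set `supports_gradient_checkpointing = True`, and peft's `prepare_model_for_kbit_training` only reaches that call when `use_gradient_checkpointing` is true (its default). The snippet is a simplified sketch of that path, not the libraries' actual code: `SketchPreTrainedModel` and `sketch_prepare_for_kbit_training` are hypothetical stand-ins, and the stand-in `OLMoForCausalLM` is assumed to leave the flag unset.

```python
# Simplified, hypothetical stand-ins; NOT the real transformers/peft classes.


class SketchPreTrainedModel:
    """Mimics the guard in transformers' PreTrainedModel.gradient_checkpointing_enable."""

    # Model classes opt in by setting this to True; assumed unset for the OLMo stand-in.
    supports_gradient_checkpointing = False

    def gradient_checkpointing_enable(self, **gc_enable_kwargs):
        # Same check and message as the frame at the bottom of the traceback.
        if not self.supports_gradient_checkpointing:
            raise ValueError(
                f"{self.__class__.__name__} does not support gradient checkpointing."
            )
        self.gradient_checkpointing = True


class OLMoForCausalLM(SketchPreTrainedModel):
    """Stand-in for the model class named in the traceback (flag left at False)."""


def sketch_prepare_for_kbit_training(model, use_gradient_checkpointing=True):
    """Hypothetical stand-in for the part of peft's prepare_model_for_kbit_training
    that enables gradient checkpointing. The real function also freezes the base
    weights and upcasts some layers, which is omitted here."""
    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()
    return model


if __name__ == "__main__":
    try:
        sketch_prepare_for_kbit_training(OLMoForCausalLM())
    except ValueError as err:
        print(err)  # OLMoForCausalLM does not support gradient checkpointing.

    # Skipping gradient checkpointing never reaches the guard.
    sketch_prepare_for_kbit_training(OLMoForCausalLM(), use_gradient_checkpointing=False)
```

Under these assumptions, passing `use_gradient_checkpointing=False` to the real `prepare_model_for_kbit_training` should sidestep the error in the same way, at the cost of higher activation memory during training; it does not add checkpointing support to the model itself.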

Sign up or log in to comment