_set_gradient_checkpointing() got an unexpected keyword argument 'enable'

#3
by ehartford - opened

I have worked around this by modifying modeling_qwen.py as follows:

# In modeling_qwen.py, this replaces the existing _set_gradient_checkpointing method (indent it to match the class body).
# Needs `from typing import Callable, Optional` at the top of the file.
def _set_gradient_checkpointing(self, enable: bool = False, gradient_checkpointing_func: Optional[Callable] = None):
    is_gradient_checkpointing_set = False

    # Newer transformers call this once on the top-level model with keyword arguments.
    if isinstance(self, QWenModel):
        self.gradient_checkpointing = enable
        self._gradient_checkpointing_func = gradient_checkpointing_func
        is_gradient_checkpointing_set = True

    # Also flag any QWenModel nested inside this model.
    for module in self.modules():
        if isinstance(module, QWenModel):
            module.gradient_checkpointing = enable
            module._gradient_checkpointing_func = gradient_checkpointing_func
            is_gradient_checkpointing_set = True

    if not is_gradient_checkpointing_set:
        raise ValueError(
            f"{self.__class__.__name__} is not compatible with gradient checkpointing. "
            "Make sure all the architecture support it by setting a boolean attribute "
            "'gradient_checkpointing' to modules of the model that uses checkpointing."
        )
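Why the error appears can be shown as a pure-Python toy (no torch or transformers needed). Newer transformers releases call the hook once on the top-level model with the keyword arguments `enable` and `gradient_checkpointing_func`, while older releases called it once per submodule, roughly `model.apply(partial(model._set_gradient_checkpointing, value=True))`. Remote code written for the old `(module, value)` signature therefore rejects `enable`. All names here are hypothetical: `ToyModule` stands in for `torch.nn.Module`, `ToyBackbone` plays the role of `QWenModel`, `fake_checkpoint` is a placeholder for a checkpoint function, and `CompatWrapper` is an assumed shim (not from this thread) that accepts both conventions.

```python
from typing import Callable, Optional


class ToyModule:
    """Minimal stand-in for torch.nn.Module: just enough for modules()."""

    def __init__(self, *children: "ToyModule") -> None:
        self._children = list(children)

    def modules(self):
        # Like nn.Module.modules(): yield self first, then every descendant.
        yield self
        for child in self._children:
            yield from child.modules()


class ToyBackbone(ToyModule):
    """Plays the role of QWenModel, the module that owns the flag."""

    def __init__(self) -> None:
        super().__init__()
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func: Optional[Callable] = None


class OldStyleWrapper(ToyModule):
    # Old hook signature: the caller visits each submodule and passes (module, value).
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ToyBackbone):
            module.gradient_checkpointing = value


class CompatWrapper(ToyModule):
    # Hypothetical shim that accepts both the old and the new calling convention.
    def _set_gradient_checkpointing(self, module=None, value=False, *,
                                    enable=None, gradient_checkpointing_func=None):
        if enable is None:
            # Old convention: called once per submodule.
            if isinstance(module, ToyBackbone):
                module.gradient_checkpointing = value
            return
        # New convention: one call on the top-level model with keyword arguments.
        found = False
        for m in self.modules():
            if isinstance(m, ToyBackbone):
                m.gradient_checkpointing = enable
                m._gradient_checkpointing_func = gradient_checkpointing_func
                found = True
        if not found:
            raise ValueError(f"{type(self).__name__} has no module that supports gradient checkpointing")


def fake_checkpoint(fn, *args):
    # Placeholder for a real checkpoint function such as torch.utils.checkpoint.checkpoint.
    return fn(*args)


if __name__ == "__main__":
    old = OldStyleWrapper(ToyBackbone())
    try:
        old._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=fake_checkpoint)
    except TypeError as err:
        print("old hook rejects the new call:", err)

    backbone = ToyBackbone()
    CompatWrapper(backbone)._set_gradient_checkpointing(
        enable=True, gradient_checkpointing_func=fake_checkpoint)
    print("new-style call sets the flag:", backbone.gradient_checkpointing)

    legacy_backbone = ToyBackbone()
    legacy = CompatWrapper(legacy_backbone)
    for m in legacy.modules():  # what the old caller did via model.apply(...)
        legacy._set_gradient_checkpointing(m, value=True)
    print("old-style call sets the flag:", legacy_backbone.gradient_checkpointing)
```

The `*` in the shim's signature makes `enable` keyword-only, matching how the new caller passes it, while `module` and `value` keep working for the per-submodule calls.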

@ehartford
Hello!
I am not the creator of this model, but I solved this problem and want to share my solution.

My solution is to check your transformers version and downgrade to one that works: pip install transformers==4.34.0
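If you pin the version, a small guard can at least flag when an environment has drifted past the pin. This is a hypothetical sketch (`parse_release`, `needs_patch`, `installed_transformers_version`, and `KNOWN_GOOD` are illustrative names, not part of transformers) that uses only the standard library's `importlib.metadata`. It conservatively treats any release newer than the pinned 4.34.0 as possibly needing the patched method, so it may flag some releases that actually work.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# The version pinned in the reply; assumed here to work with the unpatched modeling_qwen.py.
KNOWN_GOOD: Tuple[int, int, int] = (4, 34, 0)


def parse_release(version: str) -> Tuple[int, int, int]:
    """Return the leading numeric components, e.g. '4.35.0.dev0' -> (4, 35, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def needs_patch(version: str) -> bool:
    # Conservative: anything newer than the pinned release is treated as suspect.
    return parse_release(version) > KNOWN_GOOD


def installed_transformers_version() -> Optional[str]:
    # Returns None when transformers is not installed in this environment.
    try:
        return metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    installed = installed_transformers_version()
    if installed is None:
        print("transformers is not installed")
    elif needs_patch(installed):
        print(f"transformers {installed} is newer than 4.34.0; "
              "use the patched _set_gradient_checkpointing or pin the version")
    else:
        print(f"transformers {installed} is at or below the pinned 4.34.0")
```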

Thank you!

That's not a solution when you are using software that requires the latest transformers.
