add gradient checkpointing for the final_layernorm module. #77
opened by zhaoqf123
Without this change, when tuning with LoRA + gradient checkpointing, the LoRA weights of the last transformer layer (i.e., layer 27) are never updated!
For example, if we use the callback below to log the LoRA weights in each layer, TensorBoard shows no weight change at all for the last layer.
```python
from transformers import Trainer
from transformers.integrations import TensorBoardCallback


class ParamsTensorBoardCallback(TensorBoardCallback):
    """Logs summary statistics of selected parameters to TensorBoard."""

    def __init__(self, tb_writer=None, params=None, process_name=lambda x: x):
        super().__init__(tb_writer)
        self.params = params
        self._process_name = process_name

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            dict_ = {}
            model = kwargs["model"]
            for name in self.params:
                param = model.get_parameter(name)
                param = param.flatten()
                name_p = self._process_name(name)
                dict_tmp = {
                    f"{name_p}_mean": param.mean().item(),
                    f"{name_p}_max": param.max().item(),
                    f"{name_p}_q75": param.quantile(0.75).item(),
                    f"{name_p}_q25": param.quantile(0.25).item(),
                    f"{name_p}_min": param.min().item(),
                    f"{name_p}_median": param.median().item(),
                    f"{name_p}_std": param.std().item(),
                }
                dict_.update(dict_tmp)
            self.on_log(args, state, control, logs=dict_, **kwargs)


def get_params_for_logging(model):
    """Collect the names of all trainable parameters (i.e., the LoRA weights)."""
    ls_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            ls_params.append(name)
    return ls_params


ls_params = get_params_for_logging(model)
tb_cb = ParamsTensorBoardCallback(
    None,
    ls_params,
    # x[36:] drops a fixed-length prefix from the parameter name to keep the tags short
    process_name=lambda x: x[36:],
)

trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=args,
    data_collator=data_collator,
    callbacks=[tb_cb],
)
```
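For context, here is a minimal sketch of what the proposed change amounts to. The actual diff is not shown in this issue, so the attribute names (`final_layernorm`, `gradient_checkpointing`) are assumptions based on the title and ChatGLM-style naming: the final LayerNorm call in the model's forward is wrapped in `torch.utils.checkpoint.checkpoint` when gradient checkpointing is active, mirroring how the transformer layers are already handled.

```python
import torch
from torch.utils.checkpoint import checkpoint


# Hypothetical excerpt of the tail of the model's forward(); names are assumptions.
def forward_tail(self, hidden_states: torch.Tensor) -> torch.Tensor:
    if self.gradient_checkpointing and self.training:
        # Recompute the final LayerNorm during backward instead of storing its
        # activations, the same treatment the transformer layers receive.
        hidden_states = checkpoint(self.final_layernorm, hidden_states)
    else:
        hidden_states = self.final_layernorm(hidden_states)
    return hidden_states
```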
I have made a similar PR for the LLaMA model in the transformers repo.