# Spaces: Running on Zero
import contextlib

import torch
import torch.utils.checkpoint
def create_custom_forward(module):
    """Return a plain function that forwards all of its arguments to ``module``.

    Args:
        module: Any callable.

    Returns:
        A function ``fn`` such that ``fn(*a, **kw)`` returns ``module(*a, **kw)``.
    """

    def _forward(*positional, **keyword):
        return module(*positional, **keyword)

    return _forward
def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    """Call ``model(*args, **kwargs)``, optionally under activation checkpointing.

    Args:
        model: Callable to run.
        use_gradient_checkpointing: If truthy, run ``model`` through
            ``torch.utils.checkpoint.checkpoint`` (non-reentrant). Its
            intermediate activations are then recomputed during backward
            instead of being kept from the forward pass.
        use_gradient_checkpointing_offload: If truthy, also wrap the
            checkpointed call in ``torch.autograd.graph.save_on_cpu()``, so
            tensors saved for backward are held in CPU memory. This flag
            enables checkpointing by itself, even when
            ``use_gradient_checkpointing`` is falsy.
        *args: Positional arguments forwarded to ``model``.
        **kwargs: Keyword arguments forwarded to ``model`` unchanged.

    Returns:
        Whatever ``model`` returns.
    """
    if not (use_gradient_checkpointing or use_gradient_checkpointing_offload):
        return model(*args, **kwargs)

    def run_model(*inputs):
        # Bind ``kwargs`` here instead of handing them to ``checkpoint``.
        # ``checkpoint`` reserves names such as ``use_reentrant``,
        # ``context_fn``, ``determinism_check``, ``debug`` and
        # ``preserve_rng_state`` for itself. A model keyword with one of those
        # names would otherwise be swallowed by checkpoint or raise TypeError.
        # With use_reentrant=False, tensors reached through this closure
        # still receive gradients.
        return model(*inputs, **kwargs)

    # Offloading only changes where saved tensors live; the checkpointed call
    # itself is identical in both modes.
    offload_ctx = (
        torch.autograd.graph.save_on_cpu()
        if use_gradient_checkpointing_offload
        else contextlib.nullcontext()
    )
    with offload_ctx:
        return torch.utils.checkpoint.checkpoint(
            run_model,
            *args,
            use_reentrant=False,
        )