Spaces:
Runtime error
Runtime error
File size: 853 Bytes
e56055d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from utils import parent_module
def linear_backward_hook(mod, grad_in, grad_out):
    """Record the detached output gradient on ``mod.weight`` as ``__delta__``.

    Meant to be registered with ``register_full_backward_hook`` (as
    ``hook_model`` does). The gradient is stored only when
    ``linear_forward_hook`` has already cached an input on
    ``mod.weight.__x__``. Otherwise a diagnostic line is printed and
    nothing is stored.

    Args:
        mod: The hooked module.
        grad_in: Gradients w.r.t. the module inputs; not used here.
        grad_out: Gradients w.r.t. the module outputs; exactly one entry is
            expected.

    Returns:
        None, so autograd keeps ``grad_in`` unchanged.
    """
    # Guard: a module without a ``weight`` attribute has nowhere to stash data.
    if not hasattr(mod, "weight"):
        print(f"{mod} has no weight!")
        return

    param = mod.weight
    # Guard: only record a delta once the matching forward input is cached.
    if not hasattr(param, "__x__"):
        print(f"{mod} has no __x__")
        return

    assert len(grad_out) == 1
    # Detach so the stored tensor does not keep the autograd graph alive.
    # Presumably combined later with ``__x__`` (outer product) to rebuild
    # per-example weight gradients -- TODO confirm against the consumer.
    param.__delta__ = grad_out[0].detach()
def linear_forward_hook(mod, activations, output):
    """Cache the detached module input on ``mod.weight`` as ``__x__``.

    Meant to be registered with ``register_forward_hook`` (as ``hook_model``
    does). ``linear_backward_hook`` reads ``__x__`` to decide whether to
    store the output gradient.

    Args:
        mod: The hooked module.
        activations: Positional inputs to the module's forward call; exactly
            one entry is expected.
        output: The module's forward output; not used here.

    Returns:
        None, so the forward output is left unchanged.
    """
    # Match linear_backward_hook: skip modules that have no ``weight``.
    # Previously this raised AttributeError mid-forward, which made the
    # backward hook's "has no weight!" diagnostic unreachable.
    if not hasattr(mod, "weight"):
        return
    assert len(activations) == 1, "expected a single positional input"
    # Detach so the cached input does not keep the autograd graph alive.
    mod.weight.__x__ = activations[0].detach()
def hook_model(model, pnames):
    """Attach the input/gradient capture hooks to modules of ``model``.

    Each name in ``pnames`` is resolved with ``parent_module``, presumably to
    the module that owns that parameter (TODO confirm against ``utils``).
    Every distinct module gets ``linear_backward_hook`` as a full backward
    hook and ``linear_forward_hook`` as a forward hook.

    Args:
        model: Root module to hook.
        pnames: Iterable of parameter names within ``model``.

    Side effects:
        Sets ``model.handles`` to the list of hook handles, so callers can
        later ``remove()`` them.
    """
    # Resolve every name before registering anything, so a bad name fails
    # before any hook is attached.
    modules = [parent_module(model, pname) for pname in pnames]
    # Several names (e.g. "<layer>.weight" and "<layer>.bias") can resolve to
    # the same module. Hooking it twice would run every hook twice per pass.
    # Deduplicate by identity while keeping first-seen order.
    unique_modules = list({id(m): m for m in modules}.values())

    handles = []
    for m in unique_modules:
        handles.append(m.register_full_backward_hook(linear_backward_hook))
        handles.append(m.register_forward_hook(linear_forward_hook))
    # NOTE(review): this overwrites any earlier ``model.handles``. Hooks from a
    # previous call stay registered but are no longer reachable from here.
    model.handles = handles
|