import torch | |
import contextlib | |
@contextlib.contextmanager
def use_patched_ops(operations):
    """Temporarily replace selected ``torch.nn`` layer classes with custom ones.

    Inside the ``with`` block, ``torch.nn.Linear``, ``Conv2d``, ``Conv3d``,
    ``GroupNorm`` and ``LayerNorm`` are rebound to the same-named attributes of
    *operations*. The original classes are always restored on exit, including
    when the body raises.

    Args:
        operations: Any object (module, class, namespace) that exposes the
            attributes ``Linear``, ``Conv2d``, ``Conv3d``, ``GroupNorm`` and
            ``LayerNorm``.

    Raises:
        AttributeError: If *operations* lacks one of those attributes. Any
            classes already patched are restored before the error propagates.

    Note:
        This mutates the global ``torch.nn`` module, so the patch is visible to
        every thread for its duration. Only code that looks up
        ``torch.nn.<Name>`` while the block is active sees the replacement.
        Names imported earlier via ``from torch.nn import Linear`` are
        unaffected.
    """
    op_names = ('Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm')
    # Snapshot every original before patching anything, so the finally clause
    # can restore a consistent state even if patching fails part-way through.
    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
    try:
        for op_name in op_names:
            setattr(torch.nn, op_name, getattr(operations, op_name))
        yield
    finally:
        for op_name in op_names:
            setattr(torch.nn, op_name, backups[op_name])