| import torch | |
| class _WrapperModule(torch.nn.Module): | |
| def __init__(self, f): # type: ignore[no-untyped-def] | |
| super().__init__() | |
| self.f = f | |
| def forward(self, *args, **kwargs): # type: ignore[no-untyped-def] | |
| return self.f(*args, **kwargs) | |
| import torch | |
| class _WrapperModule(torch.nn.Module): | |
| def __init__(self, f): # type: ignore[no-untyped-def] | |
| super().__init__() | |
| self.f = f | |
| def forward(self, *args, **kwargs): # type: ignore[no-untyped-def] | |
| return self.f(*args, **kwargs) | |