"""Functionality for Python <-> C++ frontend inter-op."""
from torch import nn | |
class OrderedDictWrapper:
    """A wrapper around a C++ OrderedDict.

    It dynamically evaluates the OrderedDict getter on a bound C++ module, such
    that new changes on the C++ side are picked up. Otherwise accessing e.g.
    ``cpp_module._parameters`` just once would get a frozen copy of the parameters
    at the time of access. ``torch.nn.Module`` accesses ``_parameters`` et al. via
    ``self.__dict__`` so using properties does not work.

    Args:
        cpp_module: The object whose attribute is looked up on every access.
        attr: Name of the attribute on ``cpp_module`` that holds the dict.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        """Return the wrapped dict, re-read from ``cpp_module`` on each access."""
        # Must be a property: every accessor below uses ``self.cpp_dict`` as an
        # attribute, not as a call.
        return getattr(self.cpp_module, self.attr)

    # Magic methods cannot be assigned dynamically and bypass ``getattr``, so we
    # must manually override them.

    def items(self):
        """Return the items of the current underlying dict."""
        return self.cpp_dict.items()

    def keys(self):
        """Return the keys of the current underlying dict."""
        return self.cpp_dict.keys()

    def values(self):
        """Return the values of the current underlying dict."""
        return self.cpp_dict.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]
class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.

    Args:
        cpp_module: The wrapped module. Its ``_parameters``, ``_buffers`` and
            ``_modules`` attributes are exposed through ``OrderedDictWrapper``
            and every attribute whose name does not start with ``_`` is copied
            onto this wrapper.
    """

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor.
        self.cpp_module = cpp_module
        super().__init__()
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` in place to the data of every parameter, gradient and buffer.

        Args:
            fn: Callable applied to each ``.data`` tensor; its result replaces it.
            recurse: Accepted for signature compatibility; not read by this body.

        Returns:
            This wrapper.
        """
        for param in self.parameters():
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers():
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean
    @property  # type: ignore[override]
    def training(self):
        """Return the training flag read from the wrapped module."""
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Assignment is forwarded to the wrapped module's ``train`` method so
        # the wrapped module stays the single source of truth.
        self.cpp_module.train(mode)

    def __repr__(self):
        return self.cpp_module.__repr__()