import torch.fx
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode

__all__ = ['FakeTensorProp']

@compatibility(is_backward_compatible=False)
class FakeTensorProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record a fake tensor representing
    the metadata for the node. Unlike ShapeProp, (1) this propagation is
    cheap: it does the propagation with meta tensors, which do not actually
    store data, and (2) the fake tensors have much more fine-grained
    information, e.g., accurate alias information that can be consulted by
    looking at their storages.

    A small usage sketch appears at the bottom of this file.

    Args:
        module (GraphModule): The module to be executed
    """
    def run_node(self, n: Node):
        # Run the node as the base Interpreter would, then record the
        # resulting fake tensor on the node for later consumers.
        result = super().run_node(n)
        n.meta['val'] = result
        return result

    def propagate(self, *args):
        # Convert the real example inputs to fake tensors under a
        # FakeTensorMode, then interpret the whole graph with them so
        # that every node gets a 'val' entry.
        with FakeTensorMode.push() as mode:
            fake_args = [mode.from_tensor(a) for a in args]
            return super().run(*fake_args)
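

# A minimal usage sketch (an illustrative addition, not part of the upstream
# module): trace a small function, run FakeTensorProp over it, and read back
# the fake tensors recorded in node.meta['val']. The function `fn` and the
# input shape are arbitrary examples. The sketch assumes a PyTorch version
# where `FakeTensorMode.push()` (used by `propagate` above) still exists.
if __name__ == "__main__":
    import torch

    def fn(x):
        y = x + 1
        return y.relu() * 2

    gm = torch.fx.symbolic_trace(fn)
    FakeTensorProp(gm).propagate(torch.randn(4, 3))

    for node in gm.graph.nodes:
        # Every node now carries the fake result of executing it, which
        # exposes shape, dtype, device, and aliasing metadata without
        # allocating real data.
        print(node.name, node.meta.get('val'))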