# JSX_TTS/torch/fx/passes/fake_tensor_prop.py
# UMMJ's picture
# Upload 5875 files
# 9dd3461
import torch.fx
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode
__all__ = ['FakeTensorProp']


@compatibility(is_backward_compatible=False)
class FakeTensorProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record a fake tensor representing
    the metadata for the node. Unlike ShapeProp, (1) this propagation
    is cheap--it does the propagation with meta tensors which do not actually
    store data, and (2) the fake tensors have much more fine-grained information,
    e.g., they have accurate alias information that can be consulted by looking
    at the storages.

    The value computed for every executed node is stored in ``node.meta['val']``.

    Args:
        module (GraphModule): The module to be executed
    """

    def run_node(self, n: Node):
        """Execute ``n`` with the base Interpreter and record its result.

        Args:
            n (Node): The graph node to execute.

        Returns:
            The value produced by ``n``; the same object is stored in
            ``n.meta['val']``.
        """
        result = super().run_node(n)
        n.meta['val'] = result
        return result

    def propagate(self, *args):
        """Run the graph on fake versions of ``args`` inside a FakeTensorMode.

        Only ``torch.Tensor`` arguments are converted with
        ``FakeTensorMode.from_tensor``; any other argument (e.g. an int, float
        or None bound to a placeholder) is passed through unchanged, since
        ``from_tensor`` only accepts tensors.

        Args:
            *args: Values for the graph's placeholders, in order.

        Returns:
            The value returned by the graph's output node, computed under the
            fake tensor mode.
        """
        # NOTE(review): newer PyTorch releases deprecate ``Mode.push()`` in
        # favor of ``with FakeTensorMode() as mode:``; ``push()`` is kept here
        # so this file keeps working with the PyTorch snapshot it ships with.
        with FakeTensorMode.push() as mode:
            fake_args = [
                mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
                for a in args
            ]
            return super().run(*fake_args)