File size: 31,292 Bytes
9dd3461 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 |
import torch
import torch.nn as nn
import torch.overrides
from torch.nn.modules.module import _addindent
from torch.package import PackageImporter, PackageExporter
import linecache
from typing import Type, Dict, List, Any, Union, Optional, Set
from .graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
from ._compatibility import compatibility
from torch.package import Importer, sys_importer
import copy
import itertools
import sys
import traceback
from pathlib import Path
import os
import warnings
# Normal exec loses the source code, however we can work with
# the linecache module to recover it.
# Using _exec_with_source will add it to our local cache
# and then tools like TorchScript will be able to get source info.
class _EvalCacheLoader(object):
def __init__(self):
self.eval_cache = {}
self.next_id = 0
def cache(self, src: str, globals: Dict[str, Any]):
"""Store the source in a private cache, and add a lazy entry in linecache
that allows the source to be retrieved by 'filename'.
Args:
src (str): The module source to cache
globals (dict): The module globals
Returns:
str: The cache key (and dummy filename) generated for src.
"""
key = self._get_key()
self.eval_cache[key] = src
# Don't mutate globals so that this loader is only used
# to populate linecache, and doesn't interact with other modules
# that might check `__loader__`
globals_copy = globals.copy()
globals_copy['__file__'] = key
globals_copy['__name__'] = key
globals_copy['__loader__'] = self
linecache.lazycache(key, globals_copy)
return key
# Part of the loader protocol (PEP 302)
# linecache will use this method when trying to find source code
def get_source(self, module_name) -> Optional[str]:
if module_name in self.eval_cache:
return self.eval_cache[module_name]
return None
def _get_key(self):
key = f'<eval_with_key>.{self.next_id}'
self.next_id += 1
return key
# Module-wide loader. Every source string executed through
# ``_exec_with_source`` is registered here so linecache can retrieve it.
_loader = _EvalCacheLoader()


def _exec_with_source(src: str, globals: Dict[str, Any]):
    """Compile and run ``src`` in ``globals`` while keeping its source retrievable.

    The source is first registered with the module-level ``_loader``. The
    pseudo-filename it returns is used as the code object's filename, so
    frames from this code point back at the cached text.
    """
    filename = _loader.cache(src, globals)
    code_obj = compile(src, filename, 'exec')
    exec(code_obj, globals)
def _forward_from_src(src: str, globals: Dict[str, Any]):
    """Execute generated source and hand back the ``forward`` function it defines.

    ``globals`` is copied before execution, so the caller's dict is never
    mutated. ``forward`` is removed from the private namespace before it is
    returned. A KeyError is raised if ``src`` does not define ``forward``.
    """
    namespace = dict(globals)
    _exec_with_source(src, namespace)
    return namespace.pop('forward')
def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
    """Return the import line that binds ``name`` to ``obj`` in generated code.

    There are three cases:

    * Custom builtins supply their own import string.
    * Names recognised by ``_is_from_torch`` only need ``import torch``.
    * Anything else is located through ``importer.get_name`` and imported
      under the alias ``name``.
    """
    builtin = _custom_builtins.get(name)
    if builtin is not None:
        return builtin.import_str
    if _is_from_torch(name):
        return 'import torch'
    owner_module, attr = importer.get_name(obj)
    return 'from {} import {} as {}'.format(owner_module, attr, name)
def _format_import_block(globals: Dict[str, Any], importer: Importer):
    """Build the newline-separated import preamble for generated code.

    Each entry of ``globals`` is turned into an import statement via
    ``_format_import_statement``, and duplicate statements are collapsed.

    Args:
        globals: Names referenced by the generated code, mapped to the
            objects they must resolve to.
        importer: Used to locate the defining module of non-builtin objects.

    Returns:
        str: The import statements joined by ``'\\n'``, in sorted order.
    """
    import_strs: Set[str] = {
        _format_import_statement(name, obj, importer) for name, obj in globals.items()
    }
    # A set iterates in hash order, which varies between interpreter runs
    # under string-hash randomization. Sorting keeps the block identical
    # across runs, which keeps serialized / packaged source reproducible.
    return '\n'.join(sorted(import_strs))
@compatibility(is_backward_compatible=True)
def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
    """Unpickling hook paired with ``GraphModule.__reduce__``.

    Re-executes the stored forward source (prefixed with ``import_block``)
    and rebuilds the module from ``body`` via ``_deserialize_graph_module``.
    """
    # Backward compatibility: the attribute was renamed from ``code`` to
    # ``_code`` when ``code`` became a documented property. Older pickles
    # therefore only carry ``code``.
    source = body.get('_code')
    if not source:
        source = body['code']
    forward_fn = _forward_from_src(import_block + source, {})
    return _deserialize_graph_module(forward_fn, body)
@compatibility(is_backward_compatible=True)
def reduce_package_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
) -> torch.nn.Module:
    """``torch.package`` unpickling hook paired with ``GraphModule.__reduce_package__``.

    The forward source was saved into the package as its own module. It is
    imported back through ``importer`` and its ``forward`` is used to rebuild
    the GraphModule.
    """
    generated = importer.import_module(generated_module_name)
    return _deserialize_graph_module(generated.forward, body)
@compatibility(is_backward_compatible=True)
def reduce_deploy_graph_module(
    importer: PackageImporter, body: Dict[Any, Any], import_block: str
) -> torch.nn.Module:
    """Deploy unpickling hook paired with ``GraphModule.__reduce_deploy__``.

    The stored ``_code`` is executed with ``importer.patched_builtins`` as
    ``__builtins__``, and the resulting ``forward`` is used to rebuild the
    module.
    """
    namespace: Dict[str, Any] = {"__builtins__": importer.patched_builtins}
    source = body.get('_code')
    assert source is not None
    forward_fn = _forward_from_src(import_block + source, namespace)
    return _deserialize_graph_module(forward_fn, body)
def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
    """
    Deserialize a GraphModule given the dictionary of the original module,
    using the code to reconstruct the graph. We delete the actual graph before
    saving the dictionary so that changes to the in-memory graph format do not
    get serialized.

    Args:
        forward: The re-executed generated ``forward`` function.
        body: The saved ``__dict__`` of the original GraphModule (without
            ``_graph``).
    """
    # symbolic_trace pulls forward() off the class rather than the instance,
    # so a throwaway class is created here to host the deserialized forward.
    class CodeOnlyModule(torch.nn.Module):
        def __init__(self, body):
            super().__init__()
            self.__dict__ = body

    CodeOnlyModule.forward = forward

    tracer_cls = body.get('_tracer_cls')
    if tracer_cls is None:
        from ._symbolic_trace import Tracer
        tracer_cls = Tracer

    graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')

    # This is a workaround for a mypy linter issue related to
    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
    cls_tracer : Any = tracer_cls

    class KeepModules(cls_tracer):
        # Submodules were not traced into in the original GraphModule,
        # so they must stay opaque leaves when re-tracing as well.
        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
            return True

    code_only = CodeOnlyModule(body)
    extras = body.get('_tracer_extras', {})
    graph = KeepModules().trace(code_only, **extras)

    # Record the real tracer class on the graph rather than the private,
    # locally defined KeepModules subclass.
    graph._tracer_cls = tracer_cls
    gm = GraphModule(code_only, graph, class_name=graphmodule_cls_name)

    # The GraphModule constructor only keeps attributes the graph references.
    # The goal is a module as close as possible to the one that was saved,
    # so any other attribute from ``body`` is restored as well.
    for attr_name, attr_value in body.items():
        if hasattr(gm, attr_name):
            continue
        setattr(gm, attr_name, attr_value)
    return gm
# Copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'.
# This installs empty Modules where none exist yet if they are subpaths of target.
def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
*prefix, field = target.split('.')
for item in prefix:
f = getattr(from_module, item)
t = getattr(to_module, item, None)
if f is t:
# we have already installed one of its parents
# (e.g. target = root.linear.weight, but we have already installed root.linear)
# once we install a parent, we no longer need to copy the children
# since all the needed properties will already be present
return
if t is None:
t = torch.nn.Module()
setattr(to_module, item, t)
from_module, to_module = f, t
orig = getattr(from_module, field)
# If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
# So, we register it as a named buffer in the target module.
if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
to_module.register_buffer(field, orig)
else:
setattr(to_module, field, orig)
# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'.
# This installs empty Modules where none exist yet if they are subpaths of target.
def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
*prefix, field = target.split('.')
for item in prefix:
t = getattr(to_module, item, None)
if t is None:
t = torch.nn.Module()
setattr(to_module, item, t)
to_module = t
# If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
# So, we register it as a named buffer in the target module.
if isinstance(from_obj, torch.Tensor) and not isinstance(from_obj, torch.nn.Parameter):
to_module.register_buffer(field, from_obj)
else:
setattr(to_module, field, from_obj)
class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
self.cls_call = cls_call
# Previously, if an error occurred when valid
# symbolically-traced code was run with an invalid input, the
# user would see the source of the error as coming from
# `File "<eval_with_key_N">`, where N is some number. We use
# this function to generate a more informative error message. We
# return the traceback itself, a message explaining that the
# error occurred in a traced Module's generated forward
# function, and five lines of context surrounding the faulty
# line
@staticmethod
def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
# auxiliary variables (for readability)
err_lineno = frame_summary.lineno
assert err_lineno is not None
line = frame_summary.line
assert line is not None
err_line_len = len(line)
all_src_lines = linecache.getlines(frame_summary.filename)
# constituent substrings of the error message
tb_repr = traceback.format_exc()
custom_msg = ("Call using an FX-traced Module, "
f"line {err_lineno} of the traced Module's "
"generated forward function:")
before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
marker = "~" * err_line_len + "~~~ <--- HERE"
err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
# joined message
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
def __call__(self, obj, *args, **kwargs):
try:
if self.cls_call is not None:
return self.cls_call(obj, *args, **kwargs)
else:
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e:
assert e.__traceback__
topmost_framesummary: traceback.FrameSummary = \
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
if "eval_with_key" in topmost_framesummary.filename:
print(_WrappedCall._generate_error_message(topmost_framesummary),
file=sys.stderr)
raise e.with_traceback(None)
else:
raise e
@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
    """
    GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
    ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
    from that ``graph``.

    .. warning::

        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
        regenerated. However, if you edit the contents of the ``graph`` without reassigning
        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
        code.
    """
    def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
        # each instance of a graph module needs its own forward method
        # so create a new singleton class for each instance.
        # it is a subclass of the user-defined class, the only difference
        # is an extra layer to install the forward method

        # address issue described at https://github.com/pytorch/pytorch/issues/63883
        # in other words, traverse class hierarchy to fix the redundant class definition problem
        for t in cls.__mro__:
            c = t.__qualname__.split('.')[-1]
            if c != 'GraphModuleImpl':
                cls = t
                break

        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass
        return super().__new__(GraphModuleImpl)

    @compatibility(is_backward_compatible=True)
    def __init__(self,
                 root: Union[torch.nn.Module, Dict[str, Any]],
                 graph: Graph,
                 class_name: str = 'GraphModule'):
        """
        Construct a GraphModule.

        Args:

            root (Union[torch.nn.Module, Dict[str, Any]]):
                ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
                In the case that ``root`` is a Module, any references to Module-based objects (via qualified
                name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
                within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
                In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
                looked up directly in the dict's keys. The object mapped to by the Dict will be copied
                over into the appropriate place within the GraphModule's module hierarchy.

            graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation

            class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
                error messages will report as originating from ``GraphModule``. It may be helpful to set this
                to ``root``'s original name or a name that makes sense within the context of your transform.

        Raises:
            RuntimeError: if ``root`` is a dict missing a target referenced by
                the graph, or if ``root`` is neither a Module nor a dict.
        """
        super().__init__()
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, 'training'):
                self.training = root.training
            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +
                                           ' but that target was not provided in ``root``!')
                    targets_to_copy.append(node.target)
            # Sort targets in ascending order of the # of atoms.
            # This will ensure that less deeply nested attributes are assigned
            # before more deeply nested attributes. For example, foo.bar
            # will be assigned before foo.bar.baz. Otherwise, we might assign
            # the user-provided ``foo.bar`` and wipe out the previously-assigned
            # ``foo.bar.baz``
            targets_to_copy.sort(key=lambda t: t.count('.'))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')

        self.graph = graph

        # Store the Tracer class responsible for creating a Graph separately as part of the
        # GraphModule state, except when the Tracer is defined in a local namespace.
        # Locally defined Tracers are not pickleable. This is needed because torch.package will
        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
        # to re-create the Graph during deserialization.
        self._tracer_cls = None
        if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
            self._tracer_cls = self.graph._tracer_cls

        self._tracer_extras = {}
        if self.graph._tracer_extras:
            self._tracer_extras = self.graph._tracer_extras

        # Dictionary to store metadata
        self.meta : Dict[str, Any] = {}

    # TorchScript breaks trying to compile the graph setter because of the
    # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
    #
    # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
    __jit_unused_properties__ = ['graph']

    @property
    def graph(self) -> Graph:
        """
        Return the ``Graph`` underlying this ``GraphModule``
        """
        return self._graph

    @graph.setter
    def graph(self, g : Graph) -> None:
        """
        Set the underlying ``Graph`` for this ``GraphModule``. This will internally
        recompile the ``GraphModule`` so that the generated ``forward()`` function
        corresponds to ``g``
        """
        assert isinstance(g, Graph), f'Expected a Graph instance, but got {type(g)}'
        self._graph = g
        g.owning_module = self
        self.recompile()

    @compatibility(is_backward_compatible=False)
    def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"):
        """Dumps out module to ``folder`` with ``module_name`` so that it can be
        imported with ``from <folder> import <module_name>``

        Writes ``state_dict.pt``, ``module.py`` and ``__init__.py`` into
        ``folder``. Children whose type has no known-safe repr are pickled to
        ``<child_name>.pt`` and loaded from there.

        Args:

            folder (Union[str, os.PathLike]): The folder to write the code out to

            module_name (str): Top-level name to use for the ``Module`` while
                writing out the code
        """
        folder = Path(folder)
        folder.mkdir(exist_ok=True)
        torch.save(self.state_dict(), folder / 'state_dict.pt')
        tab = " " * 4
        custom_builtins = '\n'.join([v.import_str for v in _custom_builtins.values()])
        # The class header and __init__ lines must be indented: generated
        # attribute assignments below are appended at two tab levels.
        model_str = f"""
import torch
{custom_builtins}
from torch.nn import *
class {module_name}(torch.nn.Module):
    def __init__(self):
        super().__init__()
"""

        def _gen_model_repr(module: torch.nn.Module) -> Optional[str]:
            # Only these module types have a repr that is also valid
            # constructor code; everything else gets pickled.
            safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
            if type(module) in safe_reprs:
                return f"{module.__repr__()}"
            return None

        blobified_modules = []
        # Distinct loop names so the ``module_name`` parameter is not shadowed.
        for child_name, child in self.named_children():
            child_str = _gen_model_repr(child)
            if child_str is None:
                child_file = folder / f'{child_name}.pt'
                torch.save(child, child_file)
                blobified_modules.append(child_name)
                child_repr = child.__repr__().replace('\r', ' ').replace('\n', ' ')
                child_str = f"torch.load(r'{child_file}') # {child_repr}"
            model_str += f"{tab*2}self.{child_name} = {child_str}\n"

        for buffer_name, buffer in self._buffers.items():
            if buffer is None:
                continue
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"

        for param_name, param in self._parameters.items():
            if param is None:
                continue
            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"

        model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
        model_str += f"{_addindent(self.code, 4)}\n"

        module_file = folder / 'module.py'
        module_file.write_text(model_str)

        init_file = folder / '__init__.py'
        init_file.write_text('from .module import *')

        if len(blobified_modules) > 0:
            warnings.warn("Was not able to save the following children modules as reprs -"
                          f"saved as pickled files instead: {blobified_modules}")

    @compatibility(is_backward_compatible=True)
    def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
        """
        Adds the given submodule to ``self``.

        This installs empty Modules where none exist yet if they are
        subpaths of ``target``.

        Args:
            target: The fully-qualified string name of the new submodule
                (See example in ``nn.Module.get_submodule`` for how to
                specify a fully-qualified string.)
            m: The submodule itself; the actual object we want to
                install in the current Module

        Return:
            bool: Whether or not the submodule could be inserted. For
                this method to return True, each object in the chain
                denoted by ``target`` must either a) not exist yet,
                or b) reference an ``nn.Module`` (not a parameter or
                other attribute)
        """
        *prefix, field = target.split('.')
        mod: torch.nn.Module = self

        for item in prefix:

            submod = getattr(mod, item, None)

            if submod is None:
                submod = torch.nn.Module()
                setattr(mod, item, submod)

            if not isinstance(submod, torch.nn.Module):
                return False

            mod = submod

        mod.add_module(field, m)
        return True

    @compatibility(is_backward_compatible=True)
    def delete_submodule(self, target: str) -> bool:
        """
        Deletes the given submodule from ``self``.

        The module will not be deleted if ``target`` is not a valid
        target.

        Args:
            target: The fully-qualified string name of the submodule
                to delete (See example in ``nn.Module.get_submodule``
                for how to specify a fully-qualified string.)

        Returns:
            bool: Whether or not the target string referenced a
                submodule we want to delete. A return value of ``False``
                means that the ``target`` was not a valid reference to
                a submodule.
        """
        atoms = target.split(".")
        path, target_submod = atoms[:-1], atoms[-1]
        mod: torch.nn.Module = self

        # Get the parent module
        for item in path:

            if not hasattr(mod, item):
                return False

            mod = getattr(mod, item)

            if not isinstance(mod, torch.nn.Module):
                return False

        if not hasattr(mod, target_submod):
            return False

        if not isinstance(getattr(mod, target_submod), torch.nn.Module):
            return False

        delattr(mod, target_submod)
        return True

    @compatibility(is_backward_compatible=True)
    def delete_all_unused_submodules(self) -> None:
        """
        Deletes all unused submodules from ``self``.

        A Module is considered "used" if any one of the following is
        true:
        1. It has children that are used
        2. Its forward is called directly via a ``call_module`` node
        3. It has a non-Module attribute that is used from a
        ``get_attr`` node

        This method can be called to clean up an ``nn.Module`` without
        manually calling ``delete_submodule`` on each unused submodule.
        """
        # A set keeps the membership test below O(1) per module.
        used: Set[str] = set()

        # If we're looking at multiple parts of a path, join
        # them with a dot. Otherwise, return that single
        # element without doing anything to it.
        def join_fn(x: str, y: str) -> str:
            return '.'.join([x, y] if y else [x])

        for node in self.graph.nodes:

            if node.op == "call_module" or node.op == "get_attr":

                # A list of strings representing the different parts
                # of the path. For example, `foo.bar.baz` gives us
                # ["foo", "bar", "baz"]
                fullpath = node.target.split(".")

                # Progressively collect all the names of intermediate
                # modules. For example, if we have the target
                # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
                # `foo.bar.baz` to the set.
                used.update(itertools.accumulate(fullpath, join_fn))

                # For a `call_module` node, also register all recursive submodules
                # as used
                if node.op == "call_module":
                    try:
                        submod = self.get_submodule(node.target)

                        for submod_name, _ in submod.named_modules():
                            if submod_name != '':
                                used.add('.'.join([node.target, submod_name]))
                    except AttributeError:
                        # Node referenced nonexistent submodule, don't need to
                        # worry about GCing anything
                        pass

        # Materialize the list first: deleting while iterating named_modules()
        # would mutate the hierarchy being walked.
        to_delete = [name for name, _ in self.named_modules()
                     if name not in used]

        for name in to_delete:
            self.delete_submodule(name)

    @property
    def code(self) -> str:
        """
        Return the Python code generated from the ``Graph`` underlying this
        ``GraphModule``.
        """
        if not hasattr(self, '_code'):
            raise RuntimeError('Code has not been generated! Please report a bug to PyTorch')
        return self._code

    @compatibility(is_backward_compatible=True)
    def recompile(self) -> PythonCode:
        """
        Recompile this GraphModule from its ``graph`` attribute. This should be
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec
        python_code = self._graph.python_code(root_module='self')
        self._code = python_code.src

        cls = type(self)
        cls.forward = _forward_from_src(self._code, python_code.globals)

        # Determine whether this class explicitly defines a __call__ implementation
        # to wrap. If it does, save it in order to have wrapped_call invoke it.
        # If it does not, wrapped_call can use a dynamic call to super() instead.
        # In most cases, super().__call__ should be torch.nn.Module.__call__.
        # We do not want to hold a reference to Module.__call__ here; doing so will
        # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
        cls_call = cls.__call__ if "__call__" in vars(cls) else None

        if '_wrapped_call' not in vars(cls):
            cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

        def call_wrapped(self, *args, **kwargs):
            return self._wrapped_call(self, *args, **kwargs)

        cls.__call__ = call_wrapped

        return python_code

    # Passing Tracer as argument allows subclasses extending fx.GraphModule
    # to define their own Tracer (extending fx.Tracer).
    def __reduce_deploy__(self, importer: Importer):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
        del dict_without_graph['_graph']

        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, importer)
        return (reduce_deploy_graph_module, (dict_without_graph, import_block))

    def __reduce_package__(self, exporter: PackageExporter):
        dict_without_graph = self.__dict__.copy()
        dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
        del dict_without_graph['_graph']

        generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, exporter.importer)
        module_code = import_block + self.code
        exporter.save_source_string(generated_module_name, module_code)
        return (reduce_package_graph_module, (dict_without_graph, generated_module_name))

    def __reduce__(self):
        """
        Serialization of GraphModule. We serialize only the generated code, not
        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
        backward-compatibility guarantees, whereas Python source code does.
        On the deserialization side, we symbolically trace through the generated
        code to regenerate the underlying ``Graph``
        """
        dict_without_graph = self.__dict__.copy()
        # Record the class name like the package/deploy reducers do, so
        # ``_deserialize_graph_module`` can restore it after unpickling.
        dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
        python_code = self.recompile()
        import_block = _format_import_block(python_code.globals, sys_importer)
        del dict_without_graph['_graph']
        return (reduce_graph_module, (dict_without_graph, import_block))

    # because __reduce__ is defined for serialization,
    # we need to define deepcopy otherwise it will call __reduce__
    # and cause symbolic tracing to occur every time we try to copy the object
    def __deepcopy__(self, memo):
        # Allocate the result and register it in ``memo`` before copying the
        # state. Shared references and self-references inside ``__dict__``
        # then resolve to this copy instead of recursing or duplicating.
        res = GraphModule.__new__(GraphModule)
        memo[id(self)] = res
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        GraphModule.__init__(res, fake_mod, fake_mod.__dict__['_graph'])
        # ``__init__`` resets ``meta``; keep the deep-copied metadata instead.
        res.meta = fake_mod.__dict__.get('meta', {})
        return res

    def __copy__(self):
        res = GraphModule(self, self.graph)
        # ``__init__`` resets ``meta``; a shallow copy shares the original's.
        res.meta = getattr(self, 'meta', {})
        return res

    @compatibility(is_backward_compatible=False)
    def print_readable(self, print_output: bool = True) -> str:
        """
        Build the Python code generated for the current GraphModule and, nested
        inside it, the code of its child GraphModules.

        Args:
            print_output (bool): If True (the default), also print the code.

        Returns:
            str: The generated code.
        """
        verbose_python_code = self._graph.python_code(root_module='self', verbose=True)
        module_code = verbose_python_code.src
        module_code = module_code.lstrip('\n')
        module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
        module_code = _addindent(module_code, 4)

        submodule_code_list = [""]
        for submodule in self.children():
            if isinstance(submodule, GraphModule):
                # Recurse without printing; the nested text is embedded in ours.
                submodule_code_list.append(submodule.print_readable(print_output=False))
        submodule_code = "\n".join(submodule_code_list)
        submodule_code = _addindent(submodule_code, 4)

        output = module_code + submodule_code
        if print_output:
            print(output)
        return output

    def __str__(self) -> str:
        orig_str = super().__str__()
        print_readable_reminder = "# To see more debug info, please use `graph_module.print_readable()`"
        return '\n'.join([orig_str, self._code, print_readable_reminder])

    def _replicate_for_data_parallel(self):
        new_gm = self.__copy__()
        new_gm._is_replica = True
        return new_gm
# workarounds for issues in __torch_function__
# WAR for __torch_function__ not handling tensor lists,
# fix is in https://github.com/pytorch/pytorch/pull/34725
# orig_cat = torch.cat
# def patched_cat(*args, **kwargs):
# tensors = args[0]
# for t in tensors:
# if isinstance(t, Proxy):
# return t.__torch_function__(patched_cat, (), args, kwargs)
# return orig_cat(*args, **kwargs)
# patched_cat.__module__ = 'torch'
# patched_cat.__name__ = 'cat'
# torch.cat = patched_cat
|