|  |  | 
					
						
						|  | import ast | 
					
						
						|  | import inspect | 
					
						
						|  | import sys | 
					
						
						|  | import textwrap | 
					
						
						|  | import torch | 
					
						
						|  | import warnings | 
					
						
						|  |  | 
					
						
						|  | class AttributeTypeIsSupportedChecker(ast.NodeVisitor): | 
					
						
						|  | """ | 
					
						
						|  | Checks the ``__init__`` method of a given ``nn.Module`` to ensure | 
					
						
						|  | that all instance-level attributes can be properly initialized. | 
					
						
						|  |  | 
					
						
						|  | Specifically, we do type inference based on attribute values...even | 
					
						
						|  | if the attribute in question has already been typed using | 
					
						
						|  | Python3-style annotations or ``torch.jit.annotate``. This means that | 
					
						
						|  | setting an instance-level attribute to ``[]`` (for ``List``), | 
					
						
						|  | ``{}`` for ``Dict``), or ``None`` (for ``Optional``) isn't enough | 
					
						
						|  | information for us to properly initialize that attribute. | 
					
						
						|  |  | 
					
						
						|  | An object of this class can walk a given ``nn.Module``'s AST and | 
					
						
						|  | determine if it meets our requirements or not. | 
					
						
						|  |  | 
					
						
						|  | Known limitations | 
					
						
						|  | 1. We can only check the AST nodes for certain constructs; we can't | 
					
						
						|  | ``eval`` arbitrary expressions. This means that function calls, | 
					
						
						|  | class instantiations, and complex expressions that resolve to one of | 
					
						
						|  | the "empty" values specified above will NOT be flagged as | 
					
						
						|  | problematic. | 
					
						
						|  | 2. We match on string literals, so if the user decides to use a | 
					
						
						|  | non-standard import (e.g. `from typing import List as foo`), we | 
					
						
						|  | won't catch it. | 
					
						
						|  |  | 
					
						
						|  | Example: | 
					
						
						|  |  | 
					
						
						|  | .. code-block:: python | 
					
						
						|  |  | 
					
						
						|  | class M(torch.nn.Module): | 
					
						
						|  | def fn(self): | 
					
						
						|  | return [] | 
					
						
						|  |  | 
					
						
						|  | def __init__(self): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.x: List[int] = [] | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: List[int]): | 
					
						
						|  | self.x = x | 
					
						
						|  | return 1 | 
					
						
						|  |  | 
					
						
						|  | The above code will pass the ``AttributeTypeIsSupportedChecker`` | 
					
						
						|  | check since we have a function call in ``__init__``. However, | 
					
						
						|  | it will still fail later with the ``RuntimeError`` "Tried to set | 
					
						
						|  | nonexistent attribute: x. Did you forget to initialize it in | 
					
						
						|  | __init__()?". | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | nn_module - The instance of ``torch.nn.Module`` whose | 
					
						
						|  | ``__init__`` method we wish to check | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def check(self, nn_module: torch.nn.Module) -> None: | 
					
						
						|  |  | 
					
						
						|  | self.using_deprecated_ast: bool = sys.version_info < (3, 8) | 
					
						
						|  |  | 
					
						
						|  | source_lines = inspect.getsource(nn_module.__class__.__init__) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def is_useless_comment(line): | 
					
						
						|  | line = line.strip() | 
					
						
						|  | return line.startswith("#") and not line.startswith("# type:") | 
					
						
						|  | source_lines = "\n".join([l for l in source_lines.split("\n") if not is_useless_comment(l)]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | init_ast = ast.parse(textwrap.dedent(source_lines)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.class_level_annotations = list(nn_module.__annotations__.keys()) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.visiting_class_level_ann = False | 
					
						
						|  |  | 
					
						
						|  | self.visit(init_ast) | 
					
						
						|  |  | 
					
						
						|  | def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool: | 
					
						
						|  | if ann_type == "List": | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not isinstance(node, ast.List): | 
					
						
						|  | return False | 
					
						
						|  | if node.elts: | 
					
						
						|  | return False | 
					
						
						|  | elif ann_type == "Dict": | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not isinstance(node, ast.Dict): | 
					
						
						|  | return False | 
					
						
						|  | if node.keys: | 
					
						
						|  | return False | 
					
						
						|  | elif ann_type == "Optional": | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if (not self.using_deprecated_ast | 
					
						
						|  | and not isinstance(node, ast.Constant)): | 
					
						
						|  | return False | 
					
						
						|  | if (self.using_deprecated_ast | 
					
						
						|  | and not isinstance(node, ast.NameConstant)): | 
					
						
						|  | return False | 
					
						
						|  | if node.value: | 
					
						
						|  | return False | 
					
						
						|  |  | 
					
						
						|  | return True | 
					
						
						|  |  | 
					
						
						|  | def visit_Assign(self, node): | 
					
						
						|  | """ | 
					
						
						|  | If we're visiting a Call Node (the right-hand side of an | 
					
						
						|  | assignment statement), we won't be able to check the variable | 
					
						
						|  | that we're assigning to (the left-hand side of an assignment). | 
					
						
						|  | Because of this, we need to store this state in visitAssign. | 
					
						
						|  | (Luckily, we only have to do this if we're assigning to a Call | 
					
						
						|  | Node, i.e. ``torch.jit.annotate``. If we're using normal Python | 
					
						
						|  | annotations, we'll be visiting an AnnAssign Node, which has its | 
					
						
						|  | target built in.) | 
					
						
						|  | """ | 
					
						
						|  | try: | 
					
						
						|  | if (isinstance(node.value, ast.Call) | 
					
						
						|  | and node.targets[0].attr in self.class_level_annotations): | 
					
						
						|  | self.visiting_class_level_ann = True | 
					
						
						|  | except AttributeError: | 
					
						
						|  | return | 
					
						
						|  | self.generic_visit(node) | 
					
						
						|  | self.visiting_class_level_ann = False | 
					
						
						|  |  | 
					
						
						|  | def visit_AnnAssign(self, node): | 
					
						
						|  | """ | 
					
						
						|  | Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` | 
					
						
						|  | method and see if it conforms to our attribute annotation rules. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | if node.target.value.id != "self": | 
					
						
						|  | return | 
					
						
						|  | except AttributeError: | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if node.target.attr in self.class_level_annotations: | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | containers = {"List", "Dict", "Optional"} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | if node.annotation.value.id not in containers: | 
					
						
						|  | return | 
					
						
						|  | except AttributeError: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ann_type = node.annotation.value.id | 
					
						
						|  | if not self._is_empty_container(node.value, ann_type): | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | warnings.warn("The TorchScript type system doesn't support " | 
					
						
						|  | "instance-level annotations on empty non-base " | 
					
						
						|  | "types in `__init__`. Instead, either 1) use a " | 
					
						
						|  | "type annotation in the class body, or 2) wrap " | 
					
						
						|  | "the type in `torch.jit.Attribute`.") | 
					
						
						|  |  | 
					
						
						|  | def visit_Call(self, node): | 
					
						
						|  | """ | 
					
						
						|  | Visit a Call node in an ``nn.Module``'s ``__init__`` | 
					
						
						|  | method and determine if it's ``torch.jit.annotate``. If so, | 
					
						
						|  | see if it conforms to our attribute annotation rules. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.visiting_class_level_ann: | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | if (node.func.value.value.id != "torch" | 
					
						
						|  | or node.func.value.attr != "jit" | 
					
						
						|  | or node.func.attr != "annotate"): | 
					
						
						|  | self.generic_visit(node) | 
					
						
						|  | elif (node.func.value.value.id != "jit" | 
					
						
						|  | or node.func.value.attr != "annotate"): | 
					
						
						|  | self.generic_visit(node) | 
					
						
						|  | except AttributeError: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.generic_visit(node) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if len(node.args) != 2: | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | if not isinstance(node.args[0], ast.Subscript): | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | containers = {"List", "Dict", "Optional"} | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | ann_type = node.args[0].value.id | 
					
						
						|  | except AttributeError: | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | if ann_type not in containers: | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not self._is_empty_container(node.args[1], ann_type): | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | warnings.warn("The TorchScript type system doesn't support " | 
					
						
						|  | "instance-level annotations on empty non-base " | 
					
						
						|  | "types in `__init__`. Instead, either 1) use a " | 
					
						
						|  | "type annotation in the class body, or 2) wrap " | 
					
						
						|  | "the type in `torch.jit.Attribute`.") | 
					
						
						|  |  |