| | """ |
| | This module provides Source classes that track the origins of values in PyTorch Dynamo. |
| | Sources represent where values come from (e.g. local variables, globals, attributes) and |
| | are used for guard generation and code reconstruction during compilation. |
| | |
| | The module includes specialized sources for: |
| | - Local variables and synthetic locals |
| | - Global variables and constants |
| | - Object attributes and method calls |
| | - NN module specialization (specialized vs unspecialized) |
| | - Random values and tensor properties |
| | - Default argument handling |
| | - FSDP (Fully Sharded Data Parallel) modules |
| | |
| | Sources play a key role in Dynamo's guard system by tracking value origins for |
| | guard generation, and in code reconstruction by providing methods to rebuild |
| | the code needed to recreate values. |
| | """ |
| |
|
| | import dataclasses |
| | import enum |
| | import functools |
| | from typing import Any, Callable, Optional, TYPE_CHECKING, Union |
| |
|
| | from torch._guards import ChainedSource, Guard, GuardSource, Source |
| |
|
| | from . import utils |
| | from .bytecode_transformation import create_call_function, create_instruction |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from .codegen import PyCodegen |
| |
|
| | |
| | |
| |
|
| | |
# Guard-source remapping tables.
#
# Each table maps the GuardSource of a base object to the GuardSource used by an
# nn.Module-related source derived from it. The LOCAL/GLOBAL distinction of the
# base is always preserved. A GuardSource that is not a key of a table makes the
# lookup raise KeyError.

# Used by ParamBufferSource and NNModuleSource: plain locals/globals become the
# "specialized nn.Module" category.
_GUARD_SOURCE_SPECIALIZED_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
    GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
    # Identity entries: a base that is already in an unspecialized / builtin /
    # FSDP module category keeps that category instead of failing the lookup.
    GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}

# Used by UnspecializedNNModuleSource: plain locals/globals become the
# "unspecialized nn.Module" category.
_GUARD_SOURCE_UNSPECIALIZED_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
    # A specialized-module base is downgraded to unspecialized.
    # NOTE(review): presumably an unspecialized submodule reached through a
    # specialized module - confirm with callers.
    GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
    # Builtin and FSDP categories are left unchanged.
    GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}

# Used by UnspecializedBuiltinNNModuleSource: plain, specialized and
# unspecialized module bases all become the "unspecialized builtin nn.Module"
# category.
_GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    # Identity entries for bases already in the builtin or FSDP categories.
    GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}

# Used by FSDPNNModuleSource: every plain or module category collapses to the
# FSDP module category (local or global).
_GUARD_SOURCE_FSDP_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}
| |
|
| |
|
def is_constant_source(source: Source) -> bool:
    """Return True when *source* denotes a constant value.

    A source is constant if it is a ``ConstantSource`` or if its
    ``guard_source()`` is ``GuardSource.CONSTANT``. A source whose
    ``guard_source()`` raises ``NotImplementedError`` is reported as
    non-constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        return source.guard_source() == GuardSource.CONSTANT
    except NotImplementedError:
        return False
| |
|
| |
|
| | def _get_source_debug_name(source: Source) -> str: |
| | try: |
| | return source.name() |
| | except NotImplementedError: |
| | return "<unknown source>" |
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class LocalSource(Source):
    """A local variable of the frame, named ``L[<local_name>]``."""

    local_name: str
    # True when the local is an input of the frame; get_local_source_name()
    # filters on this when called with only_allow_input=True.
    is_input: bool = False
    # NOTE(review): presumably records which parts of this input are known to be
    # dynamic; it is not read anywhere in this file - confirm with callers.
    dynamism: Optional[frozenset[str]] = None
    # When True the value is the contents of a cell variable, so reconstruct()
    # emits a deref load instead of a plain local load.
    is_derefed_cell_contents: bool = False

    def guard_source(self) -> GuardSource:
        return GuardSource.LOCAL

    def name(self) -> str:
        return f"L[{self.local_name!r}]"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        if not self.is_derefed_cell_contents:
            codegen.append_output(codegen.create_load(self.local_name))
            return
        codegen.load_deref(self.local_name)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class SyntheticLocalSource(Source):
    """A synthetic local named ``SYNTHETIC_LOCAL[<local_name>]``.

    NOTE(review): presumably a local created during compilation rather than by
    user code - confirm. It is reconstructed with a plain local load.
    """

    local_name: str

    def name(self) -> str:
        return f"SYNTHETIC_LOCAL[{self.local_name!r}]"

    def guard_source(self) -> GuardSource:
        return GuardSource.SYNTHETIC_LOCAL

    def reconstruct(self, codegen: "PyCodegen") -> None:
        load_local = codegen.create_load(self.local_name)
        codegen.append_output(load_local)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class RandomValueSource(Source):
    """Entry ``random_call_index`` of the output's random-values variable.

    Reconstructed as ``<random_values_var>[random_call_index]`` and named
    ``random_value_<random_call_index>``.
    """

    random_call_index: int

    def name(self) -> str:
        return f"random_value_{self.random_call_index}"

    def guard_source(self) -> GuardSource:
        return GuardSource.RANDOM_VALUE

    def reconstruct(self, codegen: "PyCodegen") -> None:
        values_var = codegen.tx.output.random_values_var
        # Emit: <values_var>, <index>, BINARY_SUBSCR
        codegen.append_output(codegen.create_load(values_var))
        codegen.append_output(codegen.create_load_const(self.random_call_index))
        codegen.append_output(create_instruction("BINARY_SUBSCR"))
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class GlobalSource(Source):
    """A global variable, named ``G[<global_name>]``."""

    global_name: str

    def name(self) -> str:
        return f"G[{self.global_name!r}]"

    def guard_source(self) -> GuardSource:
        return GuardSource.GLOBAL

    def reconstruct(self, codegen: "PyCodegen") -> None:
        load_global = codegen.create_load_global(self.global_name, add=True)
        codegen.append_output(load_global)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class GlobalWeakRefSource(Source):
    """The result of calling a global with no arguments, ``G[<global_name>]()``.

    NOTE(review): presumably the global holds a weak reference and the call
    dereferences it - confirm at the construction sites.
    """

    global_name: str

    def name(self) -> str:
        return f"G[{self.global_name!r}]()"

    def guard_source(self) -> GuardSource:
        return GuardSource.GLOBAL

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], <global>, CALL 0
        codegen.add_push_null(
            lambda: codegen.append_output(
                codegen.create_load_global(self.global_name, add=True)
            )
        )
        zero_arg_call = create_call_function(0, False)
        codegen.extend_output(zero_arg_call)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class WeakRefCallSource(ChainedSource):
    """``base`` called with no arguments, named ``<base>()``.

    The guard source is inherited unchanged from ``base``.
    """

    def name(self) -> str:
        return f"{self.base.name()}()"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], <base>, CALL 0
        codegen.add_push_null(lambda: codegen(self.base))
        zero_arg_call = create_call_function(0, False)
        codegen.extend_output(zero_arg_call)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class CallFunctionNoArgsSource(WeakRefCallSource):
    """``base`` called with no arguments.

    A distinct type whose naming, guarding and reconstruction are all inherited
    unchanged from WeakRefCallSource.
    """
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class AttrSource(ChainedSource):
    """Attribute access ``<base>.<member>``.

    A dotted ``member`` such as ``"a.b.c"`` is normalized at construction into
    ``AttrSource(AttrSource(base, "a.b"), "c")``. The inner source splits
    recursively, so a constructed instance's ``member`` never contains a dot.
    """

    member: str

    def __post_init__(self) -> None:
        assert self.base, "Can't construct an AttrSource without a valid base source"
        if "." not in self.member:
            return
        owner_path, _, leaf = self.member.rpartition(".")
        # Frozen dataclass: bypass the generated __setattr__.
        object.__setattr__(self, "base", AttrSource(self.base, owner_path))
        object.__setattr__(self, "member", leaf)

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def name(self) -> str:
        # Only identifiers may use dot syntax; anything else goes through getattr.
        if self.member.isidentifier():
            return f"{self.base.name()}.{self.member}"
        return f"getattr({self.base.name()}, {self.member!r})"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen(self.base)
        attr_loads = codegen.create_load_attrs(self.member)
        codegen.extend_output(attr_loads)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class GenericAttrSource(ChainedSource):
    """Attribute ``member`` of ``base``, named
    ``object.__getattribute__(<base>, <member>)``.

    NOTE(review): the name presumably bypasses any class-level
    ``__getattr__``/``__getattribute__`` override. ``reconstruct`` still emits
    ordinary attribute loads. A dotted ``member`` is split so the prefix becomes
    a plain ``AttrSource`` and only the last component remains here.
    """

    member: str

    def __post_init__(self) -> None:
        assert self.base, "Can't construct an AttrSource without a valid base source"
        if "." not in self.member:
            return
        owner_path, _, leaf = self.member.rpartition(".")
        # Frozen dataclass: bypass the generated __setattr__.
        object.__setattr__(self, "base", AttrSource(self.base, owner_path))
        object.__setattr__(self, "member", leaf)

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def name(self) -> str:
        return f"object.__getattribute__({self.base.name()}, {self.member!r})"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen(self.base)
        attr_loads = codegen.create_load_attrs(self.member)
        codegen.extend_output(attr_loads)
| |
|
| |
|
| | |
@dataclasses.dataclass(frozen=True)
class TypeDictSource(ChainedSource):
    """The ``__dict__`` attribute of ``base``, named ``dict(<base>.__dict__)``."""

    def name(self) -> str:
        # NOTE(review): wrapped in dict(...) presumably because a type's
        # __dict__ is a mappingproxy while the guard machinery expects a real
        # dict - confirm against the guard accessor implementation.
        return f"dict({self.base.name()}.__dict__)"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen(self.base)
        dict_loads = codegen.create_load_attrs("__dict__")
        codegen.extend_output(dict_loads)
| |
|
| |
|
| | |
@dataclasses.dataclass(frozen=True)
class TypeMROSource(ChainedSource):
    """The ``__mro__`` attribute of ``base``, named ``<base>.__mro__``."""

    def name(self) -> str:
        return f"{self.base.name()}.__mro__"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen(self.base)
        mro_loads = codegen.create_load_attrs("__mro__")
        codegen.extend_output(mro_loads)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class LocalCellSource(Source):
    """
    The cell object itself (not its contents) for a local of the frame.

    Conceptually this is `LocalSource` for cell objects that Python creates
    implicitly (e.g., for captured variables). Only ``reconstruct`` is defined
    here; ``guard_source`` and ``name`` are not overridden.
    """

    local_name: str

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Always emit a closure load for the cell rather than a fast-local load.
        load_cell = codegen.create_load_closure(self.local_name)
        codegen.append_output(load_cell)
| |
|
| | |
| | |
| |
|
| |
|
| | |
@dataclasses.dataclass(frozen=True)
class CodeSource(ChainedSource):
    """The ``__code__`` attribute of ``base``, named ``<base>.__code__``."""

    def name(self) -> str:
        return f"{self.base.name()}.__code__"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen(self.base)
        code_loads = codegen.create_load_attrs("__code__")
        codegen.extend_output(code_loads)
| |
|
| |
|
| | |
@dataclasses.dataclass(frozen=True)
class ClosureSource(ChainedSource):
    """The ``__closure__`` attribute of ``base``, named ``<base>.__closure__``."""

    def name(self) -> str:
        return f"{self.base.name()}.__closure__"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen(self.base)
        closure_loads = codegen.create_load_attrs("__closure__")
        codegen.extend_output(closure_loads)
| |
|
| |
|
| | |
| | |
| | |
| | |
@dataclasses.dataclass(frozen=True)
class GradSource(ChainedSource):
    """Attribute ``<base>.<member>``, where ``member`` defaults to ``"grad"``."""

    member: str = "grad"

    def name(self) -> str:
        return f"{self.base.name()}.{self.member}"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen(self.base)
        member_loads = codegen.create_load_attrs(self.member)
        codegen.extend_output(member_loads)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    """An ``AttrSource`` (presumably a parameter or buffer of a module).

    The guard source is the base's guard source remapped through
    ``_GUARD_SOURCE_SPECIALIZED_NN_MODULE``.
    """

    def guard_source(self) -> GuardSource:
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[base_kind]
| |
|
| |
|
| | |
@dataclasses.dataclass(frozen=True)
class UnspecializedParamBufferSource(AttrSource):
    """An ``AttrSource`` kept as a distinct type.

    NOTE(review): presumably used for parameters/buffers of unspecialized
    modules. Unlike ParamBufferSource it inherits AttrSource's guard source
    without remapping.
    """
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
@dataclasses.dataclass(frozen=True)
class EphemeralSource(Source):
    """A source that reports itself as ephemeral and refuses to build guards.

    ``desc`` is an optional description appended to the name, e.g.
    ``<ephemeral: foo>``.
    """

    desc: Optional[str] = None

    def is_ephemeral(self) -> bool:
        return True

    def guard_source(self) -> GuardSource:
        return GuardSource.EPHEMERAL

    def make_guard(self, fn: Callable[..., Any]) -> Guard:
        raise NotImplementedError

    def name(self) -> str:
        suffix = "" if self.desc is None else ": " + self.desc
        return f"<ephemeral{suffix}>"
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class SkipGuardSource(ChainedSource):
    """Transparent wrapper around ``base``.

    Name, guard source and reconstruction all delegate to ``base``. The wrapper
    type itself is the signal. NOTE(review): presumably it tells the guard
    builder to skip guards on this value - confirm.
    """

    def name(self) -> str:
        return self.base.name()

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        inner = self.base
        inner.reconstruct(codegen)
| |
|
| |
|
class TensorProperty(enum.Enum):
    """Tensor metadata that a TensorPropertySource can refer to."""

    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self) -> str:
        """Return the name of the tensor method that reads this property."""
        method_names = {
            TensorProperty.SIZE: "size",
            TensorProperty.STRIDE: "stride",
            TensorProperty.STORAGE_OFFSET: "storage_offset",
        }
        method = method_names.get(self)
        if method is None:
            raise AssertionError(f"unhandled {self}")
        return method
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    """A size or stride entry, or the storage offset, of the tensor at ``base``.

    ``idx`` selects the dimension for SIZE and STRIDE. It must be None for
    STORAGE_OFFSET; ``__post_init__`` enforces both rules.
    """

    prop: TensorProperty
    idx: Optional[int] = None

    def __post_init__(self) -> None:
        assert self.base is not None
        needs_index = self.prop is not TensorProperty.STORAGE_OFFSET
        if needs_index:
            assert self.idx is not None
        else:
            assert self.idx is None

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def name(self) -> str:
        if self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        if self.prop in (TensorProperty.SIZE, TensorProperty.STRIDE):
            return f"{self.base.name()}.{self.prop.method_name()}()[{self.idx}]"
        raise AssertionError(f"unhandled {self.prop}")

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], utils.call_<method>, <base>[, <idx>], CALL n
        helper_name = f"call_{self.prop.method_name()}"
        codegen.add_push_null(
            lambda: codegen.load_import_from(utils.__name__, helper_name)
        )
        codegen(self.base)
        arg_count = 1
        if self.idx is not None:
            codegen.append_output(codegen.create_load_const(self.idx))
            arg_count = 2
        codegen.extend_output(create_call_function(arg_count, False))
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class IndexedSource(ChainedSource):
    """Pairs an integer ``idx`` with ``base``, named ``(<idx>, <base>)``.

    This source cannot be reconstructed.
    """

    idx: int

    def __post_init__(self) -> None:
        assert self.base is not None

    def name(self) -> str:
        return f"({self.idx}, {self.base.name()})"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        raise NotImplementedError
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class NegateSource(ChainedSource):
    """The negation of ``base``, named ``<base>.__neg__()``.

    This source cannot be reconstructed.
    """

    def __post_init__(self) -> None:
        assert self.base is not None

    def name(self) -> str:
        # Spelled as an explicit __neg__() call rather than a unary minus.
        return f"{self.base.name()}.__neg__()"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        raise NotImplementedError
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class ConvertIntSource(ChainedSource):
    """Named ``cast_symbool_to_symint_guardless(<base>)``.

    NOTE(review): presumably an integer view of a bool-valued ``base`` - confirm.
    Reconstruction emits ``base`` unchanged.
    """

    def __post_init__(self) -> None:
        assert self.base is not None

    def name(self) -> str:
        return f"cast_symbool_to_symint_guardless({self.base.name()})"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        inner = self.base
        codegen(inner)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class FlattenScriptObjectSource(ChainedSource):
    """Named ``<base>.__obj_flatten__()``; reconstruction emits ``base`` only."""

    def __post_init__(self) -> None:
        assert self.base is not None

    def name(self) -> str:
        return f"{self.base.name()}.__obj_flatten__()"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        inner = self.base
        codegen(inner)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class ScriptObjectQualifiedNameSource(ChainedSource):
    """Named ``<base>._type().qualified_name()``; reconstruction emits ``base`` only."""

    def __post_init__(self) -> None:
        assert self.base is not None

    def name(self) -> str:
        return f"{self.base.name()}._type().qualified_name()"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        inner = self.base
        codegen(inner)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class AttrProxySource(ChainedSource):
    """The object behind an attribute proxy, named ``<base>.get_base()``.

    Reconstruction emits ``base`` itself, and the guard source is inherited from
    ``base``. Decorated as a frozen dataclass like every other source in this
    module, so it gets its own generated (field-identical) dataclass methods
    instead of silently relying on the parent's.
    """

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen(self.base)

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def name(self) -> str:
        return f"{self.base.name()}.get_base()"
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class DefaultsSource(ChainedSource):
    """A default argument value of the function at ``base``.

    * ``is_kw=False``: ``<base>.__defaults__[idx_key]``; ``idx_key`` must be an int.
    * ``is_kw=True``: ``<base>.__kwdefaults__[idx_key]``; ``idx_key`` must be a str.

    ``field`` (the attribute name) and ``_name`` are derived in ``__post_init__``
    and excluded from init, repr and comparison.
    """

    idx_key: Union[int, str]
    is_kw: bool = False
    field: str = dataclasses.field(init=False, repr=False, compare=False)
    _name: str = dataclasses.field(init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        assert self.base, (
            "Base must be a valid source in order to properly track and guard this Defaults to its origin."
        )
        if self.is_kw:
            assert isinstance(self.idx_key, str)
            defaults_attr = "__kwdefaults__"
            # repr() quotes and escapes the key correctly (the old hand-written
            # '...' quoting broke on keys containing quotes or backslashes).
            key_text = repr(self.idx_key)
        else:
            assert isinstance(self.idx_key, int)
            defaults_attr = "__defaults__"
            key_text = f"{self.idx_key}"
        # Frozen dataclass: derived fields must bypass the generated __setattr__.
        object.__setattr__(self, "field", defaults_attr)
        object.__setattr__(
            self, "_name", f"{self.base.name()}.{defaults_attr}[{key_text}]"
        )

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: <base>.<field>, <idx_key>, BINARY_SUBSCR
        codegen(self.base)
        codegen.extend_output(codegen.create_load_attrs(self.field))
        codegen.append_output(codegen.create_load_const(self.idx_key))
        codegen.append_output(create_instruction("BINARY_SUBSCR"))

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def name(self) -> str:
        return self._name
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class GetItemSource(ChainedSource):
    """Subscript ``<base>[index]``.

    ``slice`` objects are unhashable before Python 3.12, so a slice index is
    stored in its ``__reduce__()`` form with ``index_is_slice=True``;
    ``unpack_slice`` rebuilds the slice.
    """

    index: Any
    index_is_slice: bool = False

    def __post_init__(self) -> None:
        assert self.base is not None
        if isinstance(self.index, slice):
            # Frozen dataclass: use object.__setattr__ directly (as AttrSource and
            # DefaultsSource do). The previous super().__setattr__ only worked
            # because the parent's frozen __setattr__ happens to pass through
            # names that are not its own fields.
            object.__setattr__(self, "index", self.index.__reduce__())
            object.__setattr__(self, "index_is_slice", True)

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: <base>, <index>, BINARY_SUBSCR
        codegen(self.base)
        if self.index_is_slice:
            codegen.append_output(codegen.create_load_const(self.unpack_slice()))
        else:
            codegen.append_output(codegen.create_load_const(self.index))
        codegen.append_output(create_instruction("BINARY_SUBSCR"))

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def unpack_slice(self) -> slice:
        """Rebuild the slice stored in ``__reduce__()`` form."""
        assert self.index_is_slice
        slice_class, slice_args = self.index
        return slice_class(*slice_args)

    def name(self) -> str:
        # A Source index would need to be rendered via its own name(); that case
        # is not supported here.
        assert not isinstance(self.index, Source)
        if self.index_is_slice:
            return f"{self.base.name()}[{self.unpack_slice()!r}]"
        else:
            return f"{self.base.name()}[{self.index!r}]"
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(ChainedSource):
    """The dict key at position ``index`` of ``base``.

    Named ``list(dict.keys(<base>))[index]`` and reconstructed through
    ``utils.dict_keys_getitem(base, index)``.
    """

    index: Any

    def is_dict_key(self) -> bool:
        return True

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def name(self) -> str:
        # dict.keys(...) is called explicitly, presumably so that a keys()
        # override on a subclass is not used.
        return f"list(dict.keys({self.base.name()}))[{self.index!r}]"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], utils.dict_keys_getitem, <base>, <index>, CALL 2
        codegen.add_push_null(
            lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
        )
        codegen(self.base)
        codegen.append_output(codegen.create_load_const(self.index))
        two_arg_call = create_call_function(2, False)
        codegen.extend_output(two_arg_call)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class NonSerializableSetGetItemSource(ChainedSource):
    """The element at position ``index`` of iterating ``base``.

    NOTE(review): ``base`` is presumably a set. Named ``list(<base>)[index]`` and
    reconstructed through ``utils.set_getitem(base, index)``. ``index`` must be a
    literal constant.
    """

    index: int

    def __post_init__(self) -> None:
        from .variables import ConstantVariable

        assert ConstantVariable.is_literal(self.index)

    def is_dict_key(self) -> bool:
        return False

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def name(self) -> str:
        return f"list({self.base.name()})[{self.index!r}]"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], utils.set_getitem, <base>, <index>, CALL 2
        codegen.add_push_null(
            lambda: codegen.load_import_from(utils.__name__, "set_getitem")
        )
        codegen(self.base)
        codegen.append_output(codegen.create_load_const(self.index))
        two_arg_call = create_call_function(2, False)
        codegen.extend_output(two_arg_call)
| |
|
| |
|
| | |
@dataclasses.dataclass(frozen=True)
class DictGetItemSource(ChainedSource):
    """Lookup ``<base>[index]``.

    ``index`` is either a ``ConstDictKeySource`` (a key identified by its
    position) or a literal constant, as asserted in ``__post_init__``.
    """

    index: Any

    def __post_init__(self) -> None:
        from .variables import ConstantVariable

        key_is_positional = isinstance(self.index, ConstDictKeySource)
        assert key_is_positional or ConstantVariable.is_literal(self.index)

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: <base>, <key>, BINARY_SUBSCR
        codegen(self.base)
        if not isinstance(self.index, Source):
            codegen.append_output(codegen.create_load_const(self.index))
        else:
            codegen(self.index)
        codegen.append_output(create_instruction("BINARY_SUBSCR"))

    def name(self) -> str:
        base_name = self.base.name()
        if isinstance(self.index, ConstDictKeySource):
            return f"{base_name}[{self.index.name()}]"
        return f"{base_name}[{self.index!r}]"
| |
|
| |
|
| | |
| | |
@dataclasses.dataclass(frozen=True)
class DictSubclassGetItemSource(ChainedSource):
    """Like DictGetItemSource, but for dict subclasses.

    The lookup goes through ``utils.dict_getitem`` (presumably
    ``dict.__getitem__``), so a ``__getitem__`` override on the subclass is not
    executed. ``index`` is either a ``ConstDictKeySource`` or a literal constant,
    as asserted in ``__post_init__``.
    """

    index: Any

    def __post_init__(self) -> None:
        from .variables import ConstantVariable

        assert isinstance(
            self.index, ConstDictKeySource
        ) or ConstantVariable.is_literal(self.index)

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], utils.dict_getitem, <base>, <key>, CALL 2
        codegen.add_push_null(
            lambda: codegen.load_import_from(utils.__name__, "dict_getitem")
        )
        codegen(self.base)
        if isinstance(self.index, Source):
            codegen(self.index)
        else:
            codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, False))

    def name(self) -> str:
        # Always spell the lookup as dict.__getitem__, for literal keys too.
        # Previously a literal key rendered as `base[key]`, which dispatches to
        # the subclass's overridden __getitem__ and so disagreed with both
        # reconstruct() and the ConstDictKeySource branch.
        if isinstance(self.index, ConstDictKeySource):
            key_text = self.index.name()
        else:
            key_text = repr(self.index)
        return f"dict.__getitem__({self.base.name()}, {key_text})"
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class ListGetItemSource(GetItemSource):
    """
    GetItemSource variant that indexes through ``list.__getitem__``.

    Reconstruction calls ``utils.list_getitem(base, index)``. Slice indices are
    rejected with RuntimeError both when naming and when reconstructing.
    """

    def name(self) -> str:
        # Source-valued indices are not supported here.
        assert not isinstance(self.index, Source)
        if self.index_is_slice:
            raise RuntimeError(
                "List[slice] is a temporary object and should not have a source"
            )
        return f"list.__getitem__({self.base.name()}, {self.index!r})"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], utils.list_getitem, <base>, <index>, CALL 2
        codegen.add_push_null(
            lambda: codegen.load_import_from(utils.__name__, "list_getitem")
        )
        codegen(self.base)
        if self.index_is_slice:
            raise RuntimeError(
                "List[slice] is a temporary object and should not have a source"
            )
        codegen.append_output(codegen.create_load_const(self.index))
        two_arg_call = create_call_function(2, False)
        codegen.extend_output(two_arg_call)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    """Item ``index`` of the tuple iterator at ``base``.

    Named ``___tuple_iterator_getitem(<base>, index)`` and reconstructed through
    ``utils.tuple_iterator_getitem``.
    """

    def name(self) -> str:
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], utils.tuple_iterator_getitem, <base>, <index>, CALL 2
        codegen.add_push_null(
            lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        )
        codegen(self.base)
        codegen.append_output(codegen.create_load_const(self.index))
        two_arg_call = create_call_function(2, False)
        codegen.extend_output(two_arg_call)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class NamedTupleFieldsSource(ChainedSource):
    """The ``_fields`` attribute of ``base``, named ``___namedtuple_fields(<base>)``."""

    def name(self) -> str:
        return f"___namedtuple_fields({self.base.name()})"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        codegen(self.base)
        fields_loads = codegen.create_load_attrs("_fields")
        codegen.extend_output(fields_loads)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class DataclassFieldsSource(ChainedSource):
    """``utils.dataclass_fields(base)``, named ``___dataclass_fields(<base>)``."""

    def name(self) -> str:
        return f"___dataclass_fields({self.base.name()})"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], utils.dataclass_fields, <base>, CALL 1
        codegen.add_push_null(
            lambda: codegen.load_import_from(utils.__name__, "dataclass_fields")
        )
        codegen(self.base)
        one_arg_call = create_call_function(1, False)
        codegen.extend_output(one_arg_call)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class TypeSource(ChainedSource):
    """``type(<base>)``, reconstructed by calling ``builtins.type``."""

    def __post_init__(self) -> None:
        assert self.base is not None

    def name(self) -> str:
        return f"type({self.base.name()})"

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], builtins.type, <base>, CALL 1
        codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
        codegen(self.base)
        one_arg_call = create_call_function(1, False)
        codegen.extend_output(one_arg_call)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class OptimizerSource(ChainedSource):
    """Transparent wrapper marking ``base`` (presumably an optimizer).

    Name, guard source and reconstruction all delegate to ``base``.
    """

    def name(self) -> str:
        return self.base.name()

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def reconstruct(self, codegen: "PyCodegen") -> None:
        inner = self.base
        codegen(inner)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class NNModuleSource(ChainedSource):
    """Wraps ``base`` into the specialized nn.Module guard category.

    Name and reconstruction are those of ``base``. Only the guard source
    changes, through ``_GUARD_SOURCE_SPECIALIZED_NN_MODULE``.
    """

    def name(self) -> str:
        return self.base.name()

    def guard_source(self) -> GuardSource:
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[base_kind]

    def reconstruct(self, codegen: "PyCodegen") -> None:
        inner = self.base
        codegen(inner)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class UnspecializedNNModuleSource(NNModuleSource):
    """NNModuleSource whose guard source is remapped through
    ``_GUARD_SOURCE_UNSPECIALIZED_NN_MODULE``."""

    def guard_source(self) -> GuardSource:
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[base_kind]
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
    """UnspecializedNNModuleSource whose guard source is remapped through
    ``_GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE``."""

    def guard_source(self) -> GuardSource:
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[base_kind]
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    """NNModuleSource whose guard source is remapped through
    ``_GUARD_SOURCE_FSDP_MODULE``."""

    def guard_source(self) -> GuardSource:
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_FSDP_MODULE[base_kind]
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class GlobalStateSource(Source):
    """Field-less source with an empty name and a GLOBAL guard source."""

    def guard_source(self) -> GuardSource:
        return GuardSource.GLOBAL

    def name(self) -> str:
        return ""
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class TorchSource(Source):
    """Points to the actual `torch` module - used instead of GlobalSource
    in case the user has overridden `torch` in their local namespace"""

    # Hand-written __init__: the dataclass decorator does not replace an
    # __init__ defined in the class body. Constructing a TorchSource therefore
    # installs an ID_MATCH guard on it immediately, as a side effect.
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Imported lazily. NOTE(review): presumably to avoid an import cycle
        # with .guards - confirm.
        from .guards import GuardBuilder, install_guard

        install_guard(self.make_guard(GuardBuilder.ID_MATCH))

    def name(self) -> str:
        # Resolves torch through __import__, independent of whatever the user
        # has bound to the name `torch`.
        return "__import__('torch')"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Bytecode equivalent of `import torch`: push level 0 and an empty
        # fromlist tuple, then IMPORT_NAME torch.
        codegen.extend_output(
            [
                codegen.create_load_const(0),
                create_instruction("BUILD_TUPLE", arg=0),
                codegen.create_import_name("torch"),
            ]
        )

    def guard_source(self) -> GuardSource:
        return GuardSource.GLOBAL
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class TorchFunctionModeStackSource(Source):
    """An entry of the torch-function mode stack.

    ``ind`` is translated to an actual stack position by
    ``TorchFunctionModeStackVariable.get_mode_index`` each time it is needed.
    """

    ind: int

    def _get_index(self) -> int:
        # Local import; NOTE(review): presumably avoids an import cycle.
        from .variables.torch_function import TorchFunctionModeStackVariable

        return TorchFunctionModeStackVariable.get_mode_index(self.ind)

    def guard_source(self) -> GuardSource:
        return GuardSource.GLOBAL

    def name(self) -> str:
        return f"___get_torch_function_mode_stack_at({self._get_index()})"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], utils.get_torch_function_mode_stack_at, <index>, CALL 1
        codegen.add_push_null(
            lambda: codegen.load_import_from(
                utils.__name__, "get_torch_function_mode_stack_at"
            )
        )
        stack_index = self._get_index()
        codegen.extend_output([codegen.create_load_const(stack_index)])
        codegen.extend_output(create_call_function(1, False))
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class ConstantSource(Source):
    """A value loaded from the global name ``source_name`` and treated as constant.

    Its guard source is CONSTANT, its name is ``source_name`` verbatim, and it
    refuses to make guards.
    """

    source_name: str

    def name(self) -> str:
        return self.source_name

    def guard_source(self) -> GuardSource:
        return GuardSource.CONSTANT

    def make_guard(self, fn: Any) -> Any:
        raise NotImplementedError

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # add=False, unlike GlobalSource. NOTE(review): presumably the name is
        # expected to be registered elsewhere - confirm.
        load_global = codegen.create_load_global(self.source_name, add=False)
        codegen.append_output(load_global)
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class NumpyTensorSource(ChainedSource):
    """Tensor view of ``base``, named ``___from_numpy(<base>)``.

    Reconstructed as ``torch.as_tensor(<base>)``.
    """

    def guard_source(self) -> GuardSource:
        return self.base.guard_source()

    def name(self) -> str:
        return f"___from_numpy({self.base.name()})"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Emit: [NULL], torch.as_tensor, <base>, CALL 1
        codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
        codegen(self.base)
        one_arg_call = create_call_function(1, False)
        codegen.extend_output(one_arg_call)
| |
|
| |
|
| | @dataclasses.dataclass(frozen=True) |
| | class SubclassAttrListSource(ChainedSource): |
| | def name(self) -> str: |
| | return f"{self.base.name()}.__tensor_flatten__()[0]" |
| |
|
| | def guard_source(self) -> GuardSource: |
| | return self.base.guard_source() |
| |
|
| |
|
| | |
| | |
| | @dataclasses.dataclass(frozen=True) |
| | class FloatTensorSource(ChainedSource): |
| | def name(self) -> str: |
| | return f"___as_tensor({self.base.name()})" |
| |
|
| | def guard_source(self) -> GuardSource: |
| | return self.base.guard_source() |
| |
|
| |
|
| | @dataclasses.dataclass(frozen=True) |
| | class CallMethodItemSource(ChainedSource): |
| | def name(self) -> str: |
| | return f"{self.base.name()}.item()" |
| |
|
| | def guard_source(self) -> GuardSource: |
| | return self.base.guard_source() |
| |
|
| |
|
| | |
| | |
| | |
| | @dataclasses.dataclass(frozen=True) |
| | class ShapeEnvSource(Source): |
| | def name(self) -> str: |
| | return "" |
| |
|
| | def guard_source(self) -> GuardSource: |
| | return GuardSource.SHAPE_ENV |
| |
|
| |
|
| | @dataclasses.dataclass(frozen=True) |
| | class BackwardStateSource(Source): |
| | def name(self) -> str: |
| | return "" |
| |
|
| | def guard_source(self) -> GuardSource: |
| | return GuardSource.BACKWARD_STATE |
| |
|
| |
|
| | def get_local_source_name( |
| | source: Source, *, only_allow_input: bool = False |
| | ) -> Optional[str]: |
| | if isinstance(source, ChainedSource): |
| | return get_local_source_name(source.base, only_allow_input=only_allow_input) |
| | if not isinstance(source, LocalSource): |
| | return None |
| | if only_allow_input and not source.is_input: |
| | return None |
| | return source.local_name |
| |
|
| |
|
def is_from_local_source(source: Source, *, only_allow_input: bool = False) -> bool:
    """Return True if ``source`` is rooted at an (acceptable) ``LocalSource``.

    This is a boolean view of ``get_local_source_name``; ``only_allow_input``
    is forwarded unchanged.
    """
    root_name = get_local_source_name(source, only_allow_input=only_allow_input)
    return root_name is not None
| |
|
| |
|
def is_from_global_source(source: Source) -> bool:
    """Return True if ``source`` is rooted at a ``GlobalSource``.

    This is a boolean view of ``get_global_source_name``.
    """
    root_name = get_global_source_name(source)
    return root_name is not None
| |
|
| |
|
| | def get_global_source_name(source: Source) -> Optional[str]: |
| | if isinstance(source, ChainedSource): |
| | return get_global_source_name(source.base) |
| | if not isinstance(source, GlobalSource): |
| | return None |
| | return source.global_name |
| |
|
| |
|
| | def is_from_nonlocal_source(source: Source) -> bool: |
| | if isinstance(source, ChainedSource): |
| | return is_from_nonlocal_source(source.base) |
| | return ( |
| | isinstance(source, LocalSource) |
| | and source.is_derefed_cell_contents |
| | and not source.is_input |
| | ) |
| |
|
| |
|
def is_from_closure_source(source: Source) -> bool:
    """Return True if ``source`` or any source down its chain is a ``ClosureSource``.

    The ``ChainedSource.base`` links are walked iteratively. The walk stops at
    the first ``ClosureSource``, or returns False at a non-chained root.
    """
    node = source
    while not isinstance(node, ClosureSource):
        if not isinstance(node, ChainedSource):
            return False
        node = node.base
    return True
| |
|
| |
|
| | def is_from_source(source: Source, target: Source) -> bool: |
| | if isinstance(source, ChainedSource): |
| | return is_from_source(source.base, target) |
| | return source == target |
| |
|
| |
|
| | @functools.lru_cache |
| | def is_from_unspecialized_nn_module_source(source: Source) -> bool: |
| | if isinstance(source, UnspecializedNNModuleSource): |
| | return True |
| | if isinstance(source, ChainedSource): |
| | return is_from_unspecialized_nn_module_source(source.base) |
| | return False |
| |
|
| |
|
| | @functools.lru_cache |
| | def is_from_unspecialized_builtin_nn_module_source(source: Source) -> bool: |
| | if isinstance(source, UnspecializedBuiltinNNModuleSource): |
| | return True |
| | if isinstance(source, ChainedSource): |
| | return is_from_unspecialized_builtin_nn_module_source(source.base) |
| | return False |
| |
|
| |
|
| | @functools.lru_cache |
| | def is_from_unspecialized_param_buffer_source(source: Source) -> bool: |
| | if isinstance(source, UnspecializedParamBufferSource): |
| | return True |
| | if isinstance(source, ChainedSource): |
| | return is_from_unspecialized_param_buffer_source(source.base) |
| | return False |
| |
|
| |
|
| | @functools.lru_cache |
| | def is_from_flatten_script_object_source(source: Source) -> bool: |
| | if isinstance(source, FlattenScriptObjectSource): |
| | return True |
| | elif isinstance(source, ChainedSource): |
| | return is_from_flatten_script_object_source(source.base) |
| | return False |
| |
|
| |
|
| | @functools.lru_cache |
| | def is_from_optimizer_source(source: Source) -> bool: |
| | if isinstance(source, OptimizerSource): |
| | return True |
| | if isinstance(source, ChainedSource): |
| | return is_from_optimizer_source(source.base) |
| | return False |
| |
|
| |
|
| | |
| | |
| | @functools.lru_cache |
| | def is_from_defaults(source: Source) -> bool: |
| | if isinstance(source, DefaultsSource): |
| | return True |
| |
|
| | |
| | if ( |
| | isinstance(source, DictGetItemSource) |
| | and isinstance(source.base, AttrSource) |
| | and source.base.member == "__kwdefaults__" |
| | ): |
| | return True |
| |
|
| | |
| | if ( |
| | isinstance(source, GetItemSource) |
| | and isinstance(source.base, AttrSource) |
| | and source.base.member == "__defaults__" |
| | ): |
| | return True |
| |
|
| | if isinstance(source, ChainedSource): |
| | return is_from_defaults(source.base) |
| | return False |
| |
|
| |
|
@functools.lru_cache
def is_from_skip_guard_source(source: Source) -> bool:
    """Return True if ``source`` or any source down its chain is a
    ``SkipGuardSource``.

    Results are memoized per source with ``functools.lru_cache``.
    """
    return isinstance(source, SkipGuardSource) or (
        isinstance(source, ChainedSource) and is_from_skip_guard_source(source.base)
    )
| |
|