| | import copy |
| | import dataclasses |
| | import logging |
| | from abc import ABC, abstractmethod |
| | from collections import defaultdict |
| | from collections.abc import Generator |
| | from contextlib import contextmanager |
| | from itertools import chain |
| | from typing import Any, Optional |
| |
|
| | from torch.utils._appending_byte_serializer import ( |
| | AppendingByteSerializer, |
| | BytesReader, |
| | BytesWriter, |
| | ) |
| | from torch.utils._ordered_set import OrderedSet |
| |
|
| |
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
| |
|
| |
|
| | @dataclasses.dataclass(frozen=True) |
| | class CacheArtifact(ABC): |
| | """ |
| | Data for each cache artifact that will be serialized and deserialized |
| | """ |
| |
|
| | key: str |
| | content: bytes = dataclasses.field(repr=False) |
| |
|
| | @staticmethod |
| | def serialize(writer: BytesWriter, cls: "CacheArtifact") -> None: |
| | writer.write_str(cls.key) |
| | writer.write_bytes(cls.content) |
| |
|
| | @staticmethod |
| | def deserialize(artifact_type: str, reader: BytesReader) -> "CacheArtifact": |
| | key = reader.read_str() |
| | content = reader.read_bytes() |
| | return CacheArtifactFactory.create(artifact_type, key, content) |
| |
|
| | @staticmethod |
| | def encode(content: Any) -> bytes: |
| | assert isinstance(content, bytes), f"Expected bytes, got {type(content)}" |
| | return content |
| |
|
| | @abstractmethod |
| | def populate_cache(self) -> None: |
| | pass |
| |
|
| | def precompile_compatible(self) -> bool: |
| | return False |
| |
|
| | @staticmethod |
| | def type() -> str: |
| | """ |
| | Returns the type of the artifact. Must be unique across all CacheArtifact classes. |
| | |
| | CacheArtifactFactory.register will add property method to CacheInfo based on this (def {type}_artifacts) |
| | that returns all artifacts for specific cache. |
| | """ |
| | raise RuntimeError("CacheArtifact is an abstract class, please use a subclass") |
| |
|
| |
|
class CacheArtifactFactory:
    """
    Registry from artifact type string to CacheArtifact subclass, with helpers
    that build artifact instances from a type string.
    """

    # Shared, class-level registry: type string -> artifact class.
    _artifact_types: dict[str, type[CacheArtifact]] = {}

    @classmethod
    def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]:
        """
        Register ``artifact_cls`` under its ``type()`` string and attach a
        matching ``{type}_artifacts`` property to CacheInfo.

        Returns ``artifact_cls`` unchanged, so it can be applied as a class
        decorator. Registering the same type string twice fails the assertion.
        """
        artifact_type_key = artifact_cls.type()
        assert artifact_cls.type() not in cls._artifact_types, (
            f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory"
        )
        cls._artifact_types[artifact_type_key] = artifact_cls

        def _keys_for_this_type(self):
            # Closes over this call's artifact_type_key.
            return self.artifacts[artifact_type_key]

        setattr(
            CacheInfo,
            f"{artifact_type_key}_artifacts",
            property(_keys_for_this_type),
        )
        return artifact_cls

    @classmethod
    def _get_artifact_type(cls, artifact_type_key: str) -> type[CacheArtifact]:
        """Look up the class registered for ``artifact_type_key``; asserts it exists."""
        registry = cls._artifact_types
        assert artifact_type_key in registry, (
            f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory"
        )
        return registry[artifact_type_key]

    @classmethod
    def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact:
        """Instantiate the registered class from already-encoded ``content``."""
        return cls._get_artifact_type(artifact_type_key)(key, content)

    @classmethod
    def encode_create(
        cls, artifact_type_key: str, key: str, content: Any
    ) -> CacheArtifact:
        """Encode raw ``content`` with the registered class, then instantiate it."""
        artifact_cls = cls._get_artifact_type(artifact_type_key)
        encoded = artifact_cls.encode(content)
        return artifact_cls(key, encoded)
|
| |
|
@dataclasses.dataclass
class CacheInfo:
    """
    Return value of serialization and deserialization for the purpose of
    instrumentation

    ``artifacts`` maps an artifact type string to the list of artifact keys
    recorded for that type.
    """

    artifacts: defaultdict[str, list[str]] = dataclasses.field(
        default_factory=lambda: defaultdict(list)
    )

    # The properties below are placeholders that only declare names and types.
    # CacheArtifactFactory.register() overwrites the attribute
    # "{type}_artifacts" on this class with a property that returns
    # self.artifacts[type] for each registered artifact type.
    @property
    def inductor_artifacts(self) -> list[str]:  # type: ignore[empty-body]
        ...

    @property
    def autotune_artifacts(self) -> list[str]:  # type: ignore[empty-body]
        ...

    @property
    def aot_autograd_artifacts(self) -> list[str]:  # type: ignore[empty-body]
        ...

    @property
    def pgo_artifacts(self) -> list[str]:  # type: ignore[empty-body]
        ...

    @property
    def precompile_aot_autograd_artifacts(self) -> list[str]:  # type: ignore[empty-body]
        ...

    @property
    def precompile_dynamo_artifacts(self) -> list[str]:  # type: ignore[empty-body]
        ...

    def add(self, artifact: "CacheArtifact") -> None:
        """Record ``artifact.key`` under the artifact's ``type()`` string."""
        self.artifacts[artifact.type()].append(artifact.key)

    def clear(self) -> None:
        """Forget every recorded key."""
        self.artifacts.clear()

    def empty(self) -> bool:
        """
        Return True when no artifact key has been recorded.

        ``artifacts`` is a defaultdict, so merely looking up a type that has no
        entries inserts an empty list for it. Checking the lists themselves,
        rather than the dict, keeps such a read from making the info look
        non-empty.
        """
        return not any(self.artifacts.values())
|
| |
|
| | def _serialize_single_cache( |
| | writer: BytesWriter, cls: "tuple[str, list[CacheArtifact]]" |
| | ) -> None: |
| | writer.write_str(cls[0]) |
| | writer.write_uint64(len(cls[1])) |
| | for artifact in cls[1]: |
| | CacheArtifact.serialize(writer, artifact) |
| |
|
| |
|
| | def _deserialize_single_cache( |
| | reader: BytesReader, |
| | ) -> "tuple[str, list[CacheArtifact]]": |
| | artifacts = [] |
| | artifact_type_key = reader.read_str() |
| | num_artifacts = reader.read_uint64() |
| | for _ in range(num_artifacts): |
| | artifacts.append(CacheArtifact.deserialize(artifact_type_key, reader)) |
| |
|
| | return artifact_type_key, artifacts |
| |
|
| |
|
# Artifacts grouped by artifact type string; the shape returned by
# CacheArtifactManager.deserialize and accepted by populate_caches.
CacheArtifactsResult = dict[str, list[CacheArtifact]]
| |
|
| |
|
class CacheArtifactManager:
    """
    Lightweight manager class for collecting and processing cache artifacts for
    hot loading

    Intended Lifecycle:
    - Execute code via torch.compile, this will call
        CacheArtifactManager.record_artifact on each cache artifact
    - Call CacheArtifactManager.serialize to convert all the cache artifacts
        to portable format
    - Call CacheArtifactManager.deserialize to hot load the cache artifacts on
        a potentially different process

    NOTE: There's no FB/FC guarantees, results of cache artifacts will not be
          used unless code version matches.

    All state is held in class attributes, so it is shared by every caller;
    use ``with_fresh_cache`` to work against temporary, empty state.
    """

    # Artifacts recorded since the last successful serialize(), grouped by type.
    _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
    # Every artifact recorded so far, used by record_artifact to skip
    # duplicates. serialize() does not clear this; only clear() does.
    _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
    # serialize() extends this with the newly recorded artifacts and then asks
    # it for bytes, so nothing is serialized until serialize() is called.
    _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
        AppendingByteSerializer(serialize_fn=_serialize_single_cache)
    )
    # Keys of the artifacts that have been handed to the serializer.
    _cache_info: CacheInfo = CacheInfo()

    @classmethod
    def clear(cls) -> None:
        """Reset all recorded state in place."""
        cls._new_cache_artifacts.clear()
        cls._seen_artifacts.clear()
        cls._serializer.clear()
        cls._cache_info.clear()

    @classmethod
    @contextmanager
    def with_fresh_cache(cls) -> Generator[None, None, None]:
        """
        Context manager that swaps in empty state for the duration of the
        ``with`` body and restores the previous state objects afterwards,
        even if the body raises.
        """
        original_new_cache_artifacts = cls._new_cache_artifacts
        original_seen_artifacts = cls._seen_artifacts
        original_serializer = cls._serializer
        original_cache_info = cls._cache_info

        cls._new_cache_artifacts = defaultdict(list)
        cls._seen_artifacts = OrderedSet()
        cls._serializer = AppendingByteSerializer(serialize_fn=_serialize_single_cache)
        cls._cache_info = cls._cache_info.__class__()
        try:
            yield
        finally:
            cls._new_cache_artifacts = original_new_cache_artifacts
            cls._seen_artifacts = original_seen_artifacts
            cls._serializer = original_serializer
            cls._cache_info = original_cache_info

    @classmethod
    def record_artifact(
        cls,
        artifact_type: str,
        key: str,
        content: Any,
    ) -> None:
        """
        Called from each caching operation to record the artifact in this
        "mega" list

        ``content`` is encoded by the class registered for ``artifact_type``.
        An artifact equal to one already seen is ignored.
        """
        artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
        if artifact in cls._seen_artifacts:
            return
        # Pass the artifact itself so it is only formatted when debug logging
        # is actually enabled.
        log.debug("Recording %s", artifact)
        cls._new_cache_artifacts[artifact_type].append(artifact)
        cls._seen_artifacts.add(artifact)

    @classmethod
    def need_serialize(cls) -> bool:
        """
        Have we seen new artifacts since last serialize call?
        """
        return bool(cls._new_cache_artifacts)

    @classmethod
    def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
        """
        Converts the "mega" list into portable format

        Returns ``(bytes, info)`` where ``info`` is a copy of the keys
        serialized so far, or ``None`` when there is nothing to serialize or
        serialization fails (the failure is logged, not raised).
        """
        pending = list(chain.from_iterable(cls._new_cache_artifacts.values()))
        if not pending and cls._cache_info.empty():
            # Nothing recorded now or earlier: report "no artifacts" instead
            # of returning bytes.
            return None

        try:
            cls._serializer.extend(cls._new_cache_artifacts.items())
            artifact_bytes = cls._serializer.to_bytes()
        except Exception:
            log.warning("Failed to pickle cache artifacts", exc_info=True)
            return None

        # Only record the keys once serialization has succeeded. Doing it
        # earlier would leave them in _cache_info after a failure while the
        # artifacts stay pending, so a retry would record each key twice.
        for artifact in pending:
            log.debug("saving: %s", artifact)
            cls._cache_info.add(artifact)
        cls._new_cache_artifacts.clear()
        # Hand back a deep copy: later serialize() calls keep adding to
        # _cache_info and must not change what this caller received.
        return artifact_bytes, copy.deepcopy(cls._cache_info)

    @staticmethod
    def deserialize(serialized_artifacts: bytes) -> Optional[CacheArtifactsResult]:
        """
        Converts the portable format back into CacheArtifacts

        Returns a dict from artifact type to artifacts, or ``None`` if the
        bytes cannot be decoded (the failure is logged, not raised).
        """
        try:
            CacheArtifactManager._ensure_cache_artifacts_registered()
            artifacts = dict(
                AppendingByteSerializer.to_list(
                    serialized_artifacts,
                    deserialize_fn=_deserialize_single_cache,
                )
            )
        except Exception:
            log.warning("Failed to un-pickle cache artifacts", exc_info=True)
            return None

        return artifacts

    @staticmethod
    def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
        """
        Call ``populate_cache`` on every artifact and return a CacheInfo
        listing the keys that were processed.
        """
        info = CacheInfo()
        for artifact in chain.from_iterable(artifacts.values()):
            log.debug("writing: %s", artifact)
            info.add(artifact)
            artifact.populate_cache()

        return info

    @classmethod
    def _ensure_cache_artifacts_registered(cls) -> None:
        """When deserializing caches in fresh process, we need to ensure that all
        cache artifacts are registered in the cache registry. This is done by
        simply importing all the cache artifacts already wrapped with register call.
        """
        # These imports exist only for their registration side effect; the
        # imported names are intentionally unused.
        # NOTE(review): CacheInfo also declares precompile_* artifact
        # properties, but no module providing those artifact types is imported
        # here - confirm whether that is intended.
        from torch._dynamo.pgo import PGOCacheArtifact  # noqa: F401
        from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
            AOTAutogradCacheArtifact,
        )
        from torch._inductor.codecache import InductorCacheArtifact  # noqa: F401
        from torch._inductor.runtime.autotune_cache import (  # noqa: F401
            AutotuneCacheArtifact,
        )
| |
|