|
import abc |
|
from dataclasses import dataclass |
|
from typing import List, Any |
|
|
|
from torch.futures import Future |
|
|
|
from .metadata import ( |
|
Metadata, |
|
MetadataIndex, |
|
) |
|
|
|
from .planner import ( |
|
LoadPlan, |
|
SavePlan, |
|
SavePlanner, |
|
LoadPlanner, |
|
) |
|
|
|
@dataclass(frozen=True)
class WriteResult:
    """
    Immutable record describing the outcome of a storage write.

    ``StorageWriter.write_data`` resolves to a list of these records, and
    ``StorageWriter.finish`` receives the lists produced by all ranks.
    """

    # Metadata index this result is associated with.
    index: MetadataIndex
    # Size of the written data, in bytes.
    size_in_bytes: int
    # Storage-specific payload; its layout is not constrained here
    # (presumably defined by the concrete StorageWriter -- TODO confirm).
    storage_data: Any
|
|
|
class StorageWriter(abc.ABC):
    """
    Storage-side interface driven by ``save_state_dict`` to persist a checkpoint.

    A single StorageWriter instance plays both the coordinator and the
    follower role in a distributed checkpoint; each instance is told which
    role it has during initialization.

    Subclasses can rely on the calls arriving in this order:

    1) (all ranks) init()
    2) (all ranks) prepare_local_plan()
    3) (coordinator) prepare_global_plan()
    4) (all ranks) write_data()
    5) (coordinator) finish()
    """

    @abc.abstractmethod
    def init(self, is_coordinator: bool) -> None:
        """
        Initialize this instance.

        Args:
            is_coordinator (bool): Whether this instance is responsible for
                coordinating the checkpoint.
        """
        ...

    @abc.abstractmethod
    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Apply storage-specific planning to this rank's local plan.

        An implementation is free to return a completely different plan, but
        the recommended approach is to attach storage-specific data through
        SavePlan::storage_data.

        Args:
            plan (SavePlan): The local plan from the ``SavePlanner`` in use.

        Returns:
            A transformed ``SavePlan`` after storage local planning
        """
        ...

    @abc.abstractmethod
    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        """
        Perform the centralized, storage-specific planning step.

        This method is only called on the coordinator instance.

        An implementation is free to return completely different plans, but
        the preferred approach is to attach storage-specific data through
        SavePlan::storage_data.

        Args:
            plans: A list of ``SavePlan`` instances, one for each rank.

        Returns:
            A list of transformed ``SavePlan`` after storage global planning
        """
        ...

    @abc.abstractmethod
    def write_data(
        self, plan: SavePlan, planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        """
        Write every item in ``plan``, using ``planner`` to resolve the data.

        For each item in the plan, a subclass should call
        ``SavePlanner::resolve_data`` to obtain the underlying object to write.

        ``resolve_data`` can allocate memory, so subclasses should call it
        lazily. For tensors, make the following assumptions:

        - They might be on any device, including not matching the one on ``WriteItem::tensor_data``
        - They might be views or not contiguous. Only the projection needs to be saved.

        Args:
            plan (SavePlan): The save plan to execute.
            planner (SavePlanner): Planner object to be used to resolve items to data.

        Returns:
            A future that completes to a list of WriteResult
        """
        ...

    @abc.abstractmethod
    def finish(
        self, metadata: Metadata, results: List[List[WriteResult]]
    ) -> None:
        """
        Write the metadata and mark the current checkpoint as successful.

        The format/schema used to serialize `metadata` is an implementation
        detail. The only requirement is that it can be recovered into the
        same object graph.

        Args:
            metadata (Metadata): metadata for the new checkpoint
            results: A list of WriteResults from all ranks.

        Returns:
            None
        """
        ...
|
|
|
class StorageReader(abc.ABC):
    """
    Storage-side interface driven by ``load_state_dict`` to read a checkpoint.

    A single StorageReader instance plays both the coordinator and the
    follower role in a distributed checkpoint; each instance is told which
    role it has during initialization.

    A subclass should expect the following sequence of calls by ``load_state_dict``:

    1) (all ranks) read_metadata()
    2) (all ranks) init
    3) (all ranks) prepare_local_plan
    4) (coordinator) prepare_global_plan
    5) (all ranks) read_data
    """

    @abc.abstractmethod
    def read_metadata(self) -> Metadata:
        """
        Read the checkpoint metadata.

        Returns:
            The metadata object associated with the checkpoint being loaded.
        """
        ...

    @abc.abstractmethod
    def init(self, metadata: Metadata, is_coordinator: bool) -> None:
        """
        Initialize this instance.

        Args:
            metadata (Metadata): The metadata schema to use.
            is_coordinator (bool): Whether this instance is responsible for
                coordinating the checkpoint.
        """
        ...

    @abc.abstractmethod
    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Apply storage-specific planning to this rank's local plan.

        An implementation is free to return a completely different plan, but
        the recommended approach is to attach storage-specific data through
        LoadPlan::storage_data.

        Args:
            plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.

        Returns:
            A transformed ``LoadPlan`` after storage local planning
        """
        ...

    @abc.abstractmethod
    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        """
        Perform the centralized planning step for loading from storage.

        This method is only called on the coordinator instance.

        An implementation is free to return completely different plans, but
        the preferred approach is to attach storage-specific data through
        LoadPlan::storage_data.

        Args:
            plans: A list of ``LoadPlan`` instances, one for each rank.

        Returns:
            A list of transformed ``LoadPlan`` after storage global planning
        """
        ...

    @abc.abstractmethod
    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Read every item in ``plan``, using ``planner`` to resolve the data.

        A subclass should call ``LoadPlanner::load_bytes`` to deserialize a
        BytesIO object into the right place.

        A subclass should call ``LoadPlanner::resolve_tensor`` to get access
        to the tensors that it should load data into.

        It's the StorageLayer's responsibility to properly schedule any
        cross-device copies required.

        Args:
            plan (LoadPlan): The local plan to execute on
            planner (LoadPlanner): The planner object to use to resolve items.

        Returns:
            A future that completes once all reads are finished.
        """
        ...
|
|