File size: 9,257 Bytes
9dd3461 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 |
from dataclasses import dataclass
import os
import dataclasses
import io
import pickle
from typing import List, Union, Dict, cast
import torch
from torch import Tensor
from torch.futures import Future
from pathlib import Path
from .metadata import (
Metadata,
MetadataIndex,
)
from .storage import (
StorageReader,
StorageWriter,
WriteResult,
)
from .planner import (
LoadItemType,
LoadPlanner,
LoadPlan,
SavePlan,
SavePlanner,
ReadItem,
WriteItem,
WriteItemType,
)
from torch.distributed._shard._utils import narrow_tensor_by_index
@dataclass
class _StorageInfo:
"""
This is the per entry storage info
"""
relative_path: str
offset: int
length: int
@dataclass
class _StoragePrefix:
prefix: str
# Extension appended to every checkpoint data file name.
# NOTE: the identifier's spelling ("SUFIX") is kept because other code refers to it.
DEFAULT_SUFIX: str = ".distcp"
def _trim(tensor: torch.Tensor) -> torch.Tensor:
tensor = tensor.detach().cpu()
if tensor.storage().size() != tensor.numel():
tensor = tensor.clone()
return tensor
def _result_from_write_item(item: WriteItem, size_in_bytes, storage_data) -> WriteResult:
    """Package where and how large ``item`` was written as a WriteResult.

    Args:
        item: the write item that was persisted; its ``index`` keys the result.
        size_in_bytes: number of bytes written for the item.
        storage_data: backend-specific location info (e.g. a _StorageInfo).
    """
    result_fields = {
        "index": item.index,
        "size_in_bytes": size_in_bytes,
        "storage_data": storage_data,
    }
    return WriteResult(**result_fields)
def _write_item(stream, data, write_item, storage_key):
    """Append one write item to ``stream`` and describe where it landed.

    BYTE_IO items must be ``io.BytesIO`` and are copied verbatim; every other
    item must be a CPU tensor and is written with ``torch.save``.

    Returns:
        A WriteResult whose storage data is a _StorageInfo holding
        ``storage_key``, the start offset and the number of bytes written.
    """
    start = stream.tell()
    if write_item.type != WriteItemType.BYTE_IO:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, stream)
    else:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    written = stream.tell() - start
    location = _StorageInfo(storage_key, start, written)
    return _result_from_write_item(write_item, written, location)
def _write_files_from_queue(
    file_queue: List,
    planner: SavePlanner,
    use_fsync: bool,
):
    """Write each queued file and collect the resulting WriteResults.

    Args:
        file_queue: list of ``(file_path, file_name, write_items)`` tuples; each
            file is created/truncated at ``file_path`` and ``file_name`` is
            recorded as the item's relative storage key.
        planner: resolves each write item to its data.
        use_fsync: when True, force each file's contents to stable storage
            before it is closed.

    Returns:
        The WriteResults of all items, in the order they were written
        (per file: BYTE_IO items first, then tensor items).
    """
    write_results = []
    for file_path, file_name, write_items in file_queue:
        bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
        tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
        with open(file_path, "wb") as stream:
            for write_item in bytes_w:
                data = planner.resolve_data(write_item)
                write_results.append(_write_item(stream, data, write_item, file_name))
            for write_item in tensor_w:
                tensor = _trim(cast(torch.Tensor, planner.resolve_data(write_item)))
                assert not tensor.is_cuda
                write_results.append(_write_item(stream, tensor, write_item, file_name))
            if use_fsync:
                # os.fsync only persists what the OS already has; flush Python's
                # userspace buffer first or the tail of the file is not synced.
                stream.flush()
                os.fsync(stream.fileno())
    return write_results
class FileSystemWriter(StorageWriter):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consists of one file per write request (or one file per
    plan when ``single_file_per_rank`` is set) plus a `.metadata` file with
    the serialized metadata.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = False,
        sync_files: bool = True,
    ) -> None:
        """
        Initialize the writer pointing to `path`

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to False.
            sync_files: force files to be synced to permanent storage. Defaults to True.

        N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__()
        self.path = Path(path)
        self.single_file_per_rank = single_file_per_rank
        self.sync_files = sync_files

    def init(self, is_coordinator: bool) -> None:
        # Nothing to set up; the directory is created in prepare_global_plan.
        pass

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        # There's no storage input in the local plan
        return plan

    def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
        """Create the checkpoint directory and tag each plan with a file-name prefix.

        The i-th plan receives ``_StoragePrefix("__{i}_")`` so data files
        produced by different plans never share a name.
        """
        self.path.mkdir(parents=True, exist_ok=True)
        return [
            dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
            for i, plan in enumerate(global_plan)
        ]

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ) -> Future[List[WriteResult]]:
        """Write all items of ``plan`` into files under ``self.path``.

        Files are named ``<prefix><counter><DEFAULT_SUFIX>`` using the prefix
        assigned in prepare_global_plan. The writes happen synchronously, so the
        returned future is already completed with the list of WriteResults.
        """
        storage_plan: _StoragePrefix = plan.storage_data
        file_count = 0

        def gen_file():
            nonlocal file_count
            file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFIX}"
            file_count += 1
            return file_name

        file_queue = []
        if self.single_file_per_rank:
            file_name = gen_file()
            file_queue.append((self.path / file_name, file_name, plan.items))
        else:
            for item in plan.items:
                file_name = gen_file()
                file_queue.append((self.path / file_name, file_name, [item]))

        results = _write_files_from_queue(
            file_queue=file_queue,
            planner=planner,
            use_fsync=self.sync_files,
        )
        fut: Future[List[WriteResult]] = Future()
        fut.set_result(results)
        return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        """Attach per-item storage info to ``metadata`` and persist it as `.metadata`.

        The metadata is pickled to `.metadata.tmp`, synced, then moved over
        `.metadata` so readers never observe a partially written file.
        """
        storage_md = {}
        for wr_list in results:
            storage_md.update({wr.index: wr.storage_data for wr in wr_list})
        metadata.storage_data = storage_md

        tmp_path = self.path / ".metadata.tmp"
        with tmp_path.open("wb") as metadata_file:
            pickle.dump(metadata, metadata_file)
            # Flush Python's buffer before fsync, otherwise the tail of the
            # pickle may still be in userspace when the sync happens.
            metadata_file.flush()
            os.fsync(metadata_file.fileno())
        # Path.replace overwrites an existing target on every platform,
        # whereas Path.rename raises FileExistsError on Windows.
        tmp_path.replace(self.path / ".metadata")
class SlicedBufferedReader(io.BufferedReader):
    """Buffered reader exposing the byte range ``[offset, offset + len)`` of a raw stream.

    Positions passed to ``seek`` and returned by ``seek``/``tell`` are relative
    to the start of the slice, so consumers (e.g. ``torch.load``) can treat the
    slice as if it were a standalone file.
    """

    # TODO override read to handle (-1) correctly; it currently reads past the slice end.
    def __init__(self, base_stream: io.RawIOBase, offset: int, len: int):
        super().__init__(base_stream)
        self.offset = offset
        self.len = len
        self.seek(0)

    def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
        """Move to a slice-relative position and return the new slice-relative position.

        Follows ``io.IOBase.seek`` semantics: SEEK_SET is relative to the slice
        start, SEEK_CUR to the current position, and SEEK_END to the slice end
        (offset is usually negative for SEEK_END).
        """
        if __whence == os.SEEK_SET:
            __offset = self.offset + __offset
        elif __whence == os.SEEK_END:
            __whence = os.SEEK_SET
            # End-relative offsets are added to the end (not subtracted), per io semantics.
            __offset = (self.offset + self.len) + __offset
        # The underlying stream reports absolute positions; translate to slice-relative
        # so seek() agrees with tell().
        return super().seek(__offset, __whence) - self.offset

    def tell(self) -> int:
        """Return the current position relative to the slice start."""
        return super().tell() - self.offset
class FileSystemReader(StorageReader):
    """StorageReader that loads a checkpoint from a directory on the file system.

    It reads the `.metadata` file and then the data files named by the
    per-item _StorageInfo entries recorded in that metadata.
    """

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        super().__init__()
        self.path = Path(path)
        # Populated from the checkpoint metadata in init().
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()

    def _slice_file(self, file, sinfo: _StorageInfo):
        """Return a reader restricted to ``sinfo``'s byte range within ``file``.

        The FileIO wrapper uses ``closefd=False`` so closing the slice does not
        close ``file``'s descriptor.
        """
        return SlicedBufferedReader(
            io.FileIO(file.fileno(), closefd=False),
            sinfo.offset,
            sinfo.length,
        )

    def read_data(
        self,
        plan: LoadPlan,
        planner: LoadPlanner
    ) -> Future[None]:
        """Load every item of ``plan`` and hand it to ``planner``.

        BYTE_IO items are passed to ``planner.load_bytes`` as an ``io.BytesIO``.
        Tensor items are loaded on CPU, narrowed to the requested region, copied
        into the tensor returned by ``planner.resolve_tensor`` and committed.
        Reads are synchronous; the returned future is already completed.
        """
        # group requests by file so each data file is opened only once
        per_file: Dict[str, List[ReadItem]] = dict()
        for read_item in plan.items:
            item_md = self.storage_data[read_item.storage_index]
            per_file.setdefault(item_md.relative_path, []).append(read_item)

        for relative_path, reqs in per_file.items():
            with (self.path / relative_path).open("rb") as file:
                # TODO sort by offset and cache the reading
                for req in reqs:
                    item_md = self.storage_data[req.storage_index]
                    # Close each slice reader deterministically instead of relying on GC.
                    with self._slice_file(file, item_md) as file_slice:
                        if req.type == LoadItemType.BYTE_IO:
                            byte_stream = io.BytesIO(file_slice.read(item_md.length))
                            byte_stream.seek(0)
                            planner.load_bytes(req, byte_stream)
                        else:
                            # NOTE(review): torch.load unpickles the data; only load
                            # checkpoints from trusted sources.
                            tensor = cast(Tensor, torch.load(file_slice, map_location="cpu"))
                            tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
                            target_tensor = planner.resolve_tensor(req).detach()
                            assert (
                                target_tensor.size() == tensor.size()
                            ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                            target_tensor.copy_(tensor)
                            planner.commit_tensor(req, target_tensor)
        fut: Future = Future()
        fut.set_result(None)
        return fut

    # Implementing the abstract functions of StorageReader
    def read_metadata(self) -> Metadata:
        """Unpickle and return the checkpoint's `.metadata` file."""
        # NOTE(review): pickle.load executes arbitrary code from the file; only
        # read checkpoints from trusted sources.
        with (self.path / ".metadata").open("rb") as metadata_file:
            return pickle.load(metadata_file)

    def init(self, metadata: Metadata, is_coordinator: bool) -> None:
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        return plan

    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        return global_plan
|