import io
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import torch
import torch.distributed.checkpoint.stateful

from .parallel import ParallelBackendType
from .utils import get_device_info
# Resolve the accelerator device type once at import time; the second tuple
# element returned by get_device_info() is unused in this module.
_device_type, _ = get_device_info()
@dataclass
class TrainState(torch.distributed.checkpoint.stateful.Stateful):
step: int = 0
observed_data_samples: int = 0
observed_num_tokens: int = 0
global_avg_losses: List[float] = field(default_factory=list)
global_max_losses: List[float] = field(default_factory=list)
log_steps: List[int] = field(default_factory=list)
def state_dict(self) -> Dict[str, Any]:
# Only checkpoint global_avg_losses and global_max_losses per log frequency
# to avoid sync overhead in every iteration.
global_avg_losses_bytes = io.BytesIO()
torch.save(self.global_avg_losses, global_avg_losses_bytes)
global_max_losses_bytes = io.BytesIO()
torch.save(self.global_max_losses, global_max_losses_bytes)
log_steps_bytes = io.BytesIO()
torch.save(self.log_steps, log_steps_bytes)
return {
"step": torch.tensor(self.step, dtype=torch.int32),
"observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int32),
"observed_num_tokens": torch.tensor(self.observed_num_tokens, dtype=torch.int32),
"global_avg_losses": global_avg_losses_bytes,
"global_max_losses": global_max_losses_bytes,
"log_steps": log_steps_bytes,
}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
state_dict["global_avg_losses"].seek(0)
state_dict["global_max_losses"].seek(0)
state_dict["log_steps"].seek(0)
self.step = state_dict["step"].item()
self.observed_data_samples = state_dict["observed_data_samples"].item()
self.observed_num_tokens = state_dict["observed_num_tokens"].item()
self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False)
self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False)
self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
@dataclass
class State:
    """Top-level container for mutable trainer state shared across the run.

    All fields default to empty/None and are populated during setup; fields
    that may legitimately be absent are typed ``Optional`` (they were
    previously annotated as non-optional with ``None`` defaults, which type
    checkers reject).
    """

    # Parallel state — presumably set during distributed initialization.
    parallel_backend: Optional[ParallelBackendType] = None
    # Training state
    train_state: Optional[TrainState] = None
    num_trainable_parameters: int = 0
    generator: Optional[torch.Generator] = None
    # Hub state
    repo_id: Optional[str] = None
    # Artifacts state
    output_dir: Optional[str] = None