Stateful Classes
Below are variations of a singleton class in the sense that all instances share the same state, which is initialized on the first instantiation.
These classes are immutable and store information about certain configurations or states.
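The shared-state behavior described above can be illustrated with the classic "Borg" pattern, in which every instance points its attribute dictionary at one class-level dictionary. This is a minimal sketch, not Accelerate's implementation. SharedState is a hypothetical name, and the sketch does not try to enforce the immutability mentioned above.

```python
class SharedState:
    """Hypothetical sketch: every instance shares one state dict."""

    _shared_state: dict = {}

    def __init__(self, **kwargs):
        # Point this instance's attribute dict at the class-level dict,
        # so all instances read and write the same attributes.
        self.__dict__ = self._shared_state
        # Only the first instantiation initializes the state.
        if not self._shared_state:
            self.initialized = True
            self.config = dict(kwargs)


first = SharedState(mixed_precision="fp16")
second = SharedState(mixed_precision="bf16")  # ignored: state already initialized
print(second.config)    # {'mixed_precision': 'fp16'}
print(first is second)  # False: distinct objects, shared state
```

Later constructor arguments are silently ignored here, which mirrors "initialized on the first instantiation" but means a caller gets no warning about a conflicting configuration.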
class accelerate.state.PartialState
Singleton class that has information about the current training environment and functions to help with process control. Designed to be used when only process control and device execution states are needed. Does not need to be initialized from Accelerator.
Available attributes:
- device (torch.device): The device to use.
- distributed_type (DistributedType): The type of distributed environment currently in use.
- local_process_index (int): The index of the current process on the current server.
- mixed_precision (str): Whether or not the current script will use mixed precision, and if so, the type of mixed precision being performed.
- num_processes (int): The number of processes currently launched in parallel.
- process_index (int): The index of the current process.
- is_last_process (bool): Whether or not the current process is the last one.
- is_main_process (bool): Whether or not the current process is the main one.
- is_local_main_process (bool): Whether or not the current process is the main one on the local node.
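The three boolean flags can be read as functions of the index attributes. The sketch below assumes the main process is global index 0, the local main process is local index 0 on each node, and the last process is index num_processes - 1. ProcessInfo is a hypothetical stand-in, not a class from Accelerate.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ProcessInfo:
    """Hypothetical stand-in for the index attributes listed above."""

    process_index: int
    local_process_index: int
    num_processes: int

    @property
    def is_main_process(self) -> bool:
        # The global process with index 0 is treated as the main process.
        return self.process_index == 0

    @property
    def is_local_main_process(self) -> bool:
        # Index 0 on each node is that node's local main process.
        return self.local_process_index == 0

    @property
    def is_last_process(self) -> bool:
        # The highest global index is the last process.
        return self.process_index == self.num_processes - 1


# Assumed layout: 2 nodes x 4 processes, global indices 4..7 on the second node.
info = ProcessInfo(process_index=4, local_process_index=0, num_processes=8)
print(info.is_main_process, info.is_local_main_process, info.is_last_process)
# False True False
```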
local_main_process_first
Lets the local main process go first inside a with block.
The other processes will enter the with block after the local main process exits.
Example:
>>> from accelerate.state import PartialState
>>> state = PartialState()
>>> with state.local_main_process_first():
...     # This will be printed first by local process 0, then in a seemingly
...     # random order by the other processes.
...     print(f"This will be printed by process {state.local_process_index}")
main_process_first
Lets the main process go first inside a with block.
The other processes will enter the with block after the main process exits.
Example:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> with accelerator.main_process_first():
...     # This will be printed first by process 0, then in a seemingly
...     # random order by the other processes.
...     print(f"This will be printed by process {accelerator.process_index}")
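The ordering described for main_process_first (and, per node, local_main_process_first) can be modeled with a barrier: non-main workers wait before entering the block, and the main worker joins the barrier only after leaving it. In this sketch threads stand in for processes and threading.Barrier stands in for a collective wait; main_first, worker and NUM_WORKERS are hypothetical names, not Accelerate code.

```python
import threading
from contextlib import contextmanager

NUM_WORKERS = 4
barrier = threading.Barrier(NUM_WORKERS)
order = []
order_lock = threading.Lock()


@contextmanager
def main_first(is_main: bool):
    # Non-main workers block here until the main worker has left its block.
    if not is_main:
        barrier.wait()
    yield
    # The main worker releases the others only after finishing its block.
    if is_main:
        barrier.wait()


def worker(index: int) -> None:
    with main_first(is_main=(index == 0)):
        with order_lock:
            order.append(index)


threads = [threading.Thread(target=worker, args=(i,)) for i in range(NUM_WORKERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(order[0])  # 0: the main worker always enters first
```

The second barrier.wait() is not wrapped in try/finally, so in this sketch an exception inside the main worker's block would leave the other workers waiting indefinitely.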
on_last_process
(function: typing.Callable[..., typing.Any])
Decorator that only runs the decorated function on the last process.
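A decorator like on_last_process can be sketched as a wrapper that checks a condition each time the function is called and skips the call otherwise. only_if, PROCESS_INDEX and NUM_PROCESSES are hypothetical; the process index is hard-coded, and returning None from skipped calls is a choice of this sketch rather than documented behavior.

```python
import functools
from typing import Any, Callable, Optional

# Assumed values for the process running this code.
PROCESS_INDEX = 3
NUM_PROCESSES = 4


def only_if(predicate: Callable[[], bool]):
    """Hypothetical decorator factory: run the function only when predicate() is True."""

    def decorator(function: Callable[..., Any]) -> Callable[..., Optional[Any]]:
        @functools.wraps(function)
        def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
            # The check happens at call time, not at decoration time.
            if predicate():
                return function(*args, **kwargs)
            # Every other process skips the call and gets None back.
            return None

        return wrapper

    return decorator


@only_if(lambda: PROCESS_INDEX == NUM_PROCESSES - 1)
def report(step: int) -> str:
    return f"step {step} reported by the last process"


print(report(10))  # step 10 reported by the last process
```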
on_local_main_process
(function: typing.Callable[..., typing.Any] = None)
Decorator that only runs the decorated function on the local main process.
Example:
# Assume we have 2 servers with 4 processes each.
from accelerate.state import PartialState
state = PartialState()
@state.on_local_main_process
def print_something():
    print("This will be printed by process 0 only on each server.")
print_something()
# On server 1:
"This will be printed by process 0 only on each server."
# On server 2:
"This will be printed by process 0 only on each server."
on_local_process
(function: typing.Callable[..., typing.Any] = None, local_process_index: int = None)
Decorator that only runs the decorated function on the process with the given index on the current node.
Example:
# Assume we have 2 servers with 4 processes each.
from accelerate import Accelerator
accelerator = Accelerator()
@accelerator.on_local_process(local_process_index=2)
def print_something():
    print(f"Printed on process {accelerator.local_process_index}")
print_something()
# On server 1:
"Printed on process 2"
# On server 2:
"Printed on process 2"
on_main_process
(function: typing.Callable[..., typing.Any] = None)
Decorator that only runs the decorated function on the main process.
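The on_local_process signature above pairs function=None with a local_process_index keyword, which is what lets it be written as @accelerator.on_local_process(local_process_index=2). A common way to build such a decorator in plain Python is functools.partial, sketched below. This is an illustration under assumptions, not Accelerate's source: on_local_index and CURRENT_LOCAL_INDEX are hypothetical, and the current process's local index is hard-coded to 2.

```python
import functools
from typing import Any, Callable, Optional

CURRENT_LOCAL_INDEX = 2  # assumed local index of this process


def on_local_index(
    function: Optional[Callable[..., Any]] = None,
    local_process_index: Optional[int] = None,
):
    """Hypothetical decorator usable as @on_local_index(local_process_index=k)."""
    if function is None:
        # Called with keyword arguments only: return a decorator that
        # remembers the index and waits for the function.
        return functools.partial(on_local_index, local_process_index=local_process_index)

    @functools.wraps(function)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
        # Run only when this process has the requested local index.
        if CURRENT_LOCAL_INDEX == local_process_index:
            return function(*args, **kwargs)
        return None

    return wrapper


@on_local_index(local_process_index=2)
def greet() -> str:
    return f"Printed on process {CURRENT_LOCAL_INDEX}"


print(greet())  # Printed on process 2
```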
on_process
(function: typing.Callable[..., typing.Any] = None, process_index: int = None)
Decorator that only runs the decorated function on the process with the given index.
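To contrast on_process with on_local_process in the two-server, four-process setting used by the earlier examples, the sketch below lists which global process indices each would select. It assumes indices are assigned contiguously node by node (0 to 3 on the first server, 4 to 7 on the second), which the text above does not state; NUM_NODES, PROCS_PER_NODE and local_index are hypothetical names.

```python
NUM_NODES = 2
PROCS_PER_NODE = 4  # assumed: every node runs the same number of processes


def local_index(process_index: int) -> int:
    # Assumes global indices are assigned contiguously node by node.
    return process_index % PROCS_PER_NODE


all_indices = range(NUM_NODES * PROCS_PER_NODE)

# on_process(process_index=2): exactly one process in the whole job.
runs_on_process = [i for i in all_indices if i == 2]
# on_local_process(local_process_index=2): one process per node.
runs_on_local_process = [i for i in all_indices if local_index(i) == 2]

print(runs_on_process)        # [2]
print(runs_on_local_process)  # [2, 6]
```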
wait_for_everyone
Stops the execution of the current process until every other process has reached that point (so this does nothing when the script is run in a single process). Useful to do before saving a model.
Example:
>>> # Assuming two GPU processes
>>> import time
>>> from accelerate.state import PartialState
>>> state = PartialState()
>>> if state.is_main_process:
...     time.sleep(2)
... else:
...     print("I'm waiting for the main process to finish its sleep...")
>>> state.wait_for_everyone()
>>> # Should print on every process at the same time
>>> print("Everyone is here")
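wait_for_everyone behaves like a barrier. In the sketch below threads stand in for processes and threading.Barrier stands in for the collective call; worker, record and events are hypothetical names. It checks the property the example above relies on: no worker gets past the barrier before the sleeping "main process" has finished.

```python
import threading
import time

NUM_WORKERS = 2
barrier = threading.Barrier(NUM_WORKERS)
events = []
events_lock = threading.Lock()


def record(message: str) -> None:
    with events_lock:
        events.append(message)


def worker(index: int) -> None:
    if index == 0:
        time.sleep(0.1)  # the "main process" does some slow work
        record("main finished sleeping")
    else:
        record("waiting for the main process")
    barrier.wait()  # stand-in for wait_for_everyone()
    record(f"worker {index} past the barrier")


threads = [threading.Thread(target=worker, args=(i,)) for i in range(NUM_WORKERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

first_past = min(events.index(f"worker {i} past the barrier") for i in range(NUM_WORKERS))
print(events.index("main finished sleeping") < first_past)  # True
```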
class accelerate.state.AcceleratorState
(mixed_precision: str = None, cpu: bool = False, dynamo_plugin = None, deepspeed_plugin = None, fsdp_plugin = None, megatron_lm_plugin = None, ipex_plugin = None, _from_accelerator: bool = False, **kwargs)
Singleton class that has information about the current training environment.
Available attributes:
- device (torch.device): The device to use.
- distributed_type (DistributedType): The type of distributed environment currently in use.
- initialized (bool): Whether or not the AcceleratorState has been initialized from Accelerator.
- local_process_index (int): The index of the current process on the current server.
- mixed_precision (str): Whether or not the current script will use mixed precision, and if so, the type of mixed precision being performed.
- num_processes (int): The number of processes currently launched in parallel.
- process_index (int): The index of the current process.
- is_last_process (bool): Whether or not the current process is the last one.
- is_main_process (bool): Whether or not the current process is the main one.
- is_local_main_process (bool): Whether or not the current process is the main one on the local node.
local_main_process_first
Lets the local main process go first inside a with block.
The other processes will enter the with block after the local main process exits.
main_process_first
Lets the main process go first inside a with block.
The other processes will enter the with block after the main process exits.
class accelerate.state.GradientState
(gradient_accumulation_plugin: typing.Optional[accelerate.utils.dataclasses.GradientAccumulationPlugin] = None)
Singleton class that has information related to gradient synchronization for gradient accumulation.
Available attributes:
- end_of_dataloader (bool): Whether we have reached the end of the current dataloader.
- remainder (int): The number of extra samples that were added from padding the dataloader.
- sync_gradients (bool): Whether the gradients should be synced across all devices.
- active_dataloader (Optional[DataLoader]): The dataloader that is currently being iterated over.
- dataloader_references (List[Optional[DataLoader]]): A list of references to the dataloaders that are being iterated over.
- num_steps (int): The number of steps to accumulate over.
- adjust_scheduler (bool): Whether the scheduler should be adjusted to account for the gradient accumulation.
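To illustrate how num_steps, end_of_dataloader and sync_gradients relate, the sketch below computes a per-batch sync schedule. It assumes gradients are synced once every num_steps batches and also on the last batch of the dataloader, so a final partial accumulation window is not dropped; sync_schedule is a hypothetical helper, not part of GradientState.

```python
from typing import List


def sync_schedule(num_batches: int, num_steps: int) -> List[bool]:
    """Hypothetical helper: for each batch, should gradients be synchronized?"""
    schedule = []
    for batch_index in range(num_batches):
        end_of_dataloader = batch_index == num_batches - 1
        # Sync after every `num_steps` accumulated batches, and also on the
        # final batch so a partial accumulation window is not dropped.
        reached_step = (batch_index + 1) % num_steps == 0
        schedule.append(reached_step or end_of_dataloader)
    return schedule


print(sync_schedule(num_batches=5, num_steps=2))
# [False, True, False, True, True]
```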