import contextlib
import logging
import math
import warnings
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union

from composer.utils import dist
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

from llmfoundry.models.utils import init_empty_weights

log = logging.getLogger(__name__)

def pop_config(cfg: DictConfig,
               key: str,
               must_exist: bool = True,
               default_value: Any = None,
               convert: bool = False) -> Any:
    """Pop a value from the main config file and return it.

    If the key does not exist, return the default_value or raise a NameError
    depending on the must_exist flag. If the convert flag is set to True, then
    we will convert the value to a python object using OmegaConf.to_container.
    """
    value = cfg.pop(key, None)
    if value is not None and convert:
        if not isinstance(value, DictConfig) and not isinstance(
                value, ListConfig):
            raise ValueError(
                f'The key {key} has a value of type {type(value)} that cannot be '
                'converted to a dict or list. Please check your yaml.')
        return om.to_container(value)
    elif value is not None:
        return value
    elif must_exist:
        raise NameError(
            f'The {key} parameter is missing and must exist for execution. Please check your yaml.'
        )
    else:
        return default_value
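# Illustrative usage of `pop_config` (a sketch, not part of the module; the
# keys below are hypothetical). Note that the key is removed from `cfg` as it
# is read:
#
#     cfg = om.create({'max_seq_len': 2048,
#                      'fsdp_config': {'sharding_strategy': 'FULL_SHARD'}})
#     max_seq_len = pop_config(cfg, 'max_seq_len', must_exist=True)       # 2048
#     fsdp_config = pop_config(cfg, 'fsdp_config', must_exist=False,
#                              default_value=None, convert=True)          # plain dict
#     seed = pop_config(cfg, 'seed', must_exist=False, default_value=17)  # 17 (key absent)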

def calculate_batch_size_info(
    global_batch_size: int, device_microbatch_size: Union[int, Literal['auto']]
) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]:
    if global_batch_size % dist.get_world_size() != 0:
        raise ValueError(
            f'Global batch size {global_batch_size} is not divisible by the '
            f'world size {dist.get_world_size()}, so the batch size would be '
            'truncated. Please adjust `global_batch_size` to be divisible by '
            f'the world size, {dist.get_world_size()}.')
    device_batch_size = global_batch_size // dist.get_world_size()
    if device_microbatch_size == 'auto':
        device_grad_accum = 'auto'
    elif isinstance(device_microbatch_size, int):
        if device_microbatch_size > device_batch_size:
            log.warning(
                'device_microbatch_size > device_batch_size, '
                f'will be reduced from {device_microbatch_size} -> {device_batch_size}.'
            )
            device_microbatch_size = device_batch_size
        device_grad_accum = math.ceil(device_batch_size /
                                      device_microbatch_size)
    else:
        raise ValueError(f'Not sure how to parse {device_microbatch_size=}')

    return device_batch_size, device_microbatch_size, device_grad_accum
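# Illustrative example (a sketch; the numbers assume a hypothetical 8-GPU run,
# i.e. dist.get_world_size() == 8):
#
#     calculate_batch_size_info(512, 16)      # (64, 16, 4): 64 per device, 4 grad-accum steps
#     calculate_batch_size_info(512, 'auto')  # (64, 'auto', 'auto'): microbatching resolved downstream
#     calculate_batch_size_info(500, 16)      # raises ValueError: 500 not divisible by the world size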

def update_batch_size_info(cfg: DictConfig) -> DictConfig:
    device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(
        cfg.global_train_batch_size, cfg.device_train_microbatch_size)
    cfg.n_gpus = dist.get_world_size()
    cfg.device_train_batch_size = device_train_batch_size
    cfg.device_train_microbatch_size = device_train_microbatch_size
    cfg.device_train_grad_accum = device_train_grad_accum

    # Set device_eval_batch_size if it was not provided by the user.
    if 'device_eval_batch_size' not in cfg:
        if cfg.device_train_microbatch_size == 'auto':
            # 'auto' gives no concrete size to reuse for eval, so default to 1.
            cfg.device_eval_batch_size = 1
        else:
            cfg.device_eval_batch_size = cfg.device_train_microbatch_size
    return cfg
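# Illustrative example (a sketch with a minimal, hypothetical config; assumes
# an 8-GPU run):
#
#     cfg = om.create({'global_train_batch_size': 512,
#                      'device_train_microbatch_size': 16})
#     cfg = update_batch_size_info(cfg)
#     # cfg.device_train_batch_size == 64, cfg.device_train_grad_accum == 4,
#     # and cfg.device_eval_batch_size falls back to 16 (the microbatch size).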

def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
    # Choose the context manager used to initialize the model and, where
    # needed, adjust the FSDP config to match the requested init_device.
    init_context = contextlib.nullcontext()
    if 'init_device' in model_cfg:
        assert model_cfg.init_device in ['meta', 'cpu', 'mixed']
        if fsdp_config is None and model_cfg.init_device == 'meta':
            warnings.warn(
                "Using `cfg.model.init_device='meta'` is only valid when using FSDP! "
                "Reverting to `cfg.model.init_device='cpu'`.")
            model_cfg.init_device = 'cpu'
        if model_cfg.init_device == 'meta':
            init_context = init_empty_weights()
        if model_cfg.init_device == 'mixed':
            if fsdp_config is None:
                raise NotImplementedError(
                    'Using init_device `mixed` is only supported with FSDP. '
                    'Please add a FSDP config.')
            # Mixed initialization requires broadcasting module states from
            # rank 0, so make sure `sync_module_states` is enabled.
            if not fsdp_config.get('sync_module_states', False):
                warnings.warn((
                    'Setting `sync_module_states = True` for FSDP. This is required '
                    'when using mixed initialization.'))
                fsdp_config['sync_module_states'] = True

            fsdp_config.setdefault('use_orig_params', False)
            fsdp_config.setdefault('load_monolith_rank0_only', True)

    # If master weights are already kept in a low-precision dtype, configure
    # FSDP mixed precision so params are not cast again, while preserving any
    # user-specified reduce/buffer dtypes and keeping low-precision grads.
    master_dtype = model_cfg.get('master_weights_dtype')
    small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16',
                    'amp_bf16')
    if fsdp_config and master_dtype in small_dtypes:
        reduce_dtype = None
        buffer_dtype = None
        mixed_precision = fsdp_config.get('mixed_precision')
        if isinstance(mixed_precision, Mapping):
            reduce_dtype = mixed_precision.get('reduce_dtype')
            buffer_dtype = mixed_precision.get('buffer_dtype')
        fsdp_config['mixed_precision'] = {
            'param_dtype': None,
            'reduce_dtype': reduce_dtype,
            'buffer_dtype': buffer_dtype,
            'keep_low_precision_grads': True,
        }

    return init_context
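# Illustrative example (a sketch; both configs are hypothetical). With
# init_device='mixed' the FSDP config is patched in place and a no-op context
# is returned; with init_device='meta' (and an FSDP config) the
# `init_empty_weights` context is returned instead:
#
#     model_cfg = om.create({'init_device': 'mixed'})
#     fsdp_config = {'sharding_strategy': 'FULL_SHARD'}
#     init_context = process_init_device(model_cfg, fsdp_config)
#     # fsdp_config now contains sync_module_states=True (a warning is emitted),
#     # use_orig_params=False and load_monolith_rank0_only=True;
#     # init_context is a contextlib.nullcontext().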

def log_config(cfg: DictConfig) -> None:
    """Logs the current config and updates the wandb and MLflow configs.

    This function can be called multiple times to update the wandb and MLflow
    config with different variables.
    """
    print(om.to_yaml(cfg))
    if 'wandb' in cfg.get('loggers', {}):
        try:
            import wandb
        except ImportError as e:
            raise e
        if wandb.run:
            wandb.config.update(om.to_container(cfg, resolve=True))

    if 'mlflow' in cfg.get('loggers', {}):
        try:
            import mlflow
        except ImportError as e:
            raise e
        if mlflow.active_run():
            mlflow.log_params(params=om.to_container(cfg, resolve=True))
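# Illustrative example (a sketch; the config fragment is hypothetical and
# assumes the wandb package is installed). The config is always printed as
# yaml; it is additionally forwarded to wandb/MLflow only when the matching
# logger is configured and a run is active:
#
#     cfg = om.create({'loggers': {'wandb': {'project': 'my-project'}},
#                      'global_train_batch_size': 512})
#     log_config(cfg)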