import json
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional

import numpy as np
import tyro
from typing_extensions import Annotated

from trl.trainer.utils import exact_div

from ..core import flatten_dict
from ..import_utils import is_wandb_available


JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)]
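# For example, with a tyro-driven CLI (an illustrative sketch, not tied to any one script):
#   python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_entity"}}'
# the JSON string is parsed by `json.loads` into a plain dict before it reaches `PPOConfig`.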


@dataclass
class PPOConfig:
    """
    Configuration class for PPOTrainer
    """

    exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
    """The name of this experiment (defaults to the script file name without the `.py` extension)"""
    seed: int = 0
    """Seed value for random number generation"""
    log_with: Optional[Literal["wandb", "tensorboard"]] = None
    """Log with either 'wandb' or 'tensorboard'; see https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
    task_name: Optional[str] = None
    """Name of task to use - used only for tracking purposes"""
    model_name: Optional[str] = "gpt2"
    """Name of model to use - used only for tracking purposes"""
    query_dataset: Optional[str] = "imdb"
    """Name of dataset to query - used only for tracking purposes"""
    reward_model: Optional[str] = "sentiment-analysis:lvwerra/distilbert-imdb"
    """The reward model to use - used only for tracking purposes"""
    remove_unused_columns: bool = True
    """Remove unused columns from the dataset if `datasets.Dataset` is used"""
    tracker_kwargs: JSONDict = field(default_factory=dict)
    """Keyword arguments for the tracker (e.g. python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}')"""
    accelerator_kwargs: JSONDict = field(default_factory=dict)
    """Keyword arguments for the accelerator"""
    project_kwargs: JSONDict = field(default_factory=dict)
    """Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
    tracker_project_name: str = "trl"
    """Name of project to use for tracking"""
    push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict)
    """Keyword arguments for pushing the model to the hub during training (e.g. repo_id)"""

    steps: int = 20000
    """Number of training steps"""
    learning_rate: float = 1.41e-5
    """Adam learning rate"""
    adap_kl_ctrl: bool = True
    """Use adaptive KL control, otherwise linear"""
    init_kl_coef: Optional[float] = 0.2
    """Initial KL penalty coefficient (used for adaptive and linear control)"""
    kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl"
    """kl penalty options: 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution"""
    target: Optional[float] = 6
    """Target KL value for adaptive KL control"""
    horizon: Optional[float] = 10000
    """Horizon for adaptive KL control"""
    gamma: float = 1
    """Gamma parameter for advantage calculation"""
    lam: float = 0.95
    """Lambda parameter for advantage calculation"""
    cliprange: float = 0.2
    """Range for clipping in PPO policy gradient loss"""
    cliprange_value: float = 0.2
    """Range for clipping values in loss calculation"""
    vf_coef: float = 0.1
    """Scaling factor for value loss"""
    batch_size: int = 128
    """Number of samples per optimisation step"""
    forward_batch_size: Optional[int] = None
    """DEPRECATED: use `mini_batch_size` instead, which does the same thing."""
    mini_batch_size: int = 128
    """Number of samples optimized in each mini batch"""
    gradient_accumulation_steps: int = 1
    """The number of gradient accumulation steps"""
    world_size: tyro.conf.Suppress[int] = None
    """The world size for distributed training"""
    ppo_epochs: int = 4
    """Number of optimisation epochs per batch of samples"""
    max_grad_norm: Optional[float] = None
    """Maximum gradient norm for gradient clipping"""
    optimize_cuda_cache: Optional[bool] = None
    """DEPRECATED: use `optimize_device_cache` instead, which does the same thing."""
    optimize_device_cache: Optional[bool] = False
    """Optimize device cache for slightly more memory-efficient training"""
    early_stopping: bool = False
    """Whether to stop the PPO optimization loop early if the KL is too high"""
    target_kl: float = 1
    """Stop early if we exceed this value by over 50%"""
    compare_steps: int = 1
    """Number of steps between comparison of the current reward with the best seen so far"""
    ratio_threshold: float = 10.0
    """Skip mini-batches with high PPO ratios that can cause loss spikes"""
    use_score_scaling: bool = False
    """Use score scaling"""
    use_score_norm: bool = False
    """Use score normalization. Only applicable if use_score_scaling is True"""
    score_clip: Optional[float] = None
    """Score clipping"""
    whiten_rewards: bool = False
    """Whiten the rewards before computing advantages"""

    is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None
    """TO BE FILLED AT RUNTIME: whether the model is an encoder-decoder model"""
    is_peft_model: Optional[tyro.conf.Suppress[bool]] = None
    """TO BE FILLED AT RUNTIME: whether the model is a PEFT model"""
    backward_batch_size: tyro.conf.Suppress[int] = None
    """TO BE FILLED AT RUNTIME: number of samples optimized in an `optimizer.step()` call"""
    global_backward_batch_size: tyro.conf.Suppress[int] = None
    """TO BE FILLED AT RUNTIME: the effective `backward_batch_size` across all processes"""
    global_batch_size: tyro.conf.Suppress[int] = None
    """TO BE FILLED AT RUNTIME: the effective `batch_size` across all processes"""

    def __post_init__(self):
        # Handle the deprecated `optimize_cuda_cache` alias here, on the instance, so that a
        # user-provided value is actually picked up (a class-level check would only run once,
        # at class-definition time, and would never see it).
        if self.optimize_cuda_cache is not None:
            warnings.warn("The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead.")
            self.optimize_device_cache = self.optimize_cuda_cache

        if self.forward_batch_size is not None:
            warnings.warn(
                "Note that using `forward_batch_size` is deprecated, use `mini_batch_size` instead. By setting it you overwrite `mini_batch_size` which affects both the batch size during forward passes and also the mini batch size for PPO optimization."
            )
            self.mini_batch_size = self.forward_batch_size

        self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps
        exact_div(
            self.batch_size,
            self.backward_batch_size,
            "`batch_size`",
            "`mini_batch_size * gradient_accumulation_steps`",
            "`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`",
        )
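        # Worked example (illustrative values, not required defaults): with mini_batch_size=32 and
        # gradient_accumulation_steps=4, backward_batch_size is 32 * 4 = 128, so batch_size=128 or
        # 256 passes the check above, while batch_size=200 would raise an error here.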

        if self.log_with == "wandb":
            if not is_wandb_available():
                raise ImportError("Please install wandb to use wandb logging. You can do this by running `pip install wandb`.")

        self.total_ppo_epochs = int(np.ceil(self.steps / self.batch_size))
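        # e.g. with the defaults steps=20000 and batch_size=128: ceil(20000 / 128) = ceil(156.25) = 157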
        assert self.kl_penalty in ["kl", "abs", "mse", "full"]

    def to_dict(self):
        output_dict = {}
        for key, value in self.__dict__.items():
            output_dict[key] = value
        return flatten_dict(output_dict)
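

# A minimal usage sketch (illustrative, not part of the library API surface):
#
#   config = PPOConfig(batch_size=256, mini_batch_size=32, gradient_accumulation_steps=4)
#   config.backward_batch_size   # 128, computed in __post_init__ as 32 * 4
#   config.total_ppo_epochs      # ceil(20000 / 256) = 79
#   config.to_dict()             # flat dict of all fields, e.g. for experiment tracking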