#!/usr/bin/env bash

# Refuse to run under a shell other than bash: the script stops when
# BASH_VERSION is empty. POSIX `[` is used on purpose so that this check itself
# is valid syntax in other shells.
if [ -z "${BASH_VERSION}" ]; then
  echo "Please use bash to run this script." >&2
  exit 1
fi

# Trace every command as it is executed.
set -x

# Absolute directory that contains this script, and its parent directory.
# `--` keeps a script path that starts with "-" from being read as an option.
SCRIPT_DIR="$(cd -- "$(dirname -- "$0")" &>/dev/null && pwd)"
ROOT_DIR="$(dirname -- "${SCRIPT_DIR}")"
# Put ROOT_DIR first on PYTHONPATH, keeping any pre-existing entries after it.
export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
# Keep a caller-provided LOGLEVEL; default to WARNING when unset or empty.
export LOGLEVEL="${LOGLEVEL:-WARNING}"

# Default settings; each one can be overridden by the matching command-line
# option.
ACTOR_MODEL_NAME_OR_PATH="PKU-Alignment/alpaca-7b-reproduced"
REWARD_MODEL_NAME_OR_PATH="${ROOT_DIR}/output/rm"
COST_MODEL_NAME_OR_PATH="${ROOT_DIR}/output/cm"
# The critic paths are left unset (not merely empty) so that "not given" can be
# told apart from "given as empty" when the fallbacks are applied later.
unset REWARD_CRITIC_MODEL_NAME_OR_PATH COST_CRITIC_MODEL_NAME_OR_PATH
OUTPUT_DIR="${ROOT_DIR}/output/ppo-lag"
# Drop any HOSTFILE inherited from the environment; only --hostfile sets it.
unset HOSTFILE
ZERO_STAGE=1
OFFLOAD="none"

#######################################
# Abort when an option that takes a value is the last word on the command line.
# Arguments: $1 - option name as typed
#            $2 - number of command-line words still unread
# Outputs:   error message on stderr when no value is available
# Returns:   0 if a value is available; otherwise exits the script with 1
#######################################
require_value() {
  if [[ "$2" -eq 0 ]]; then
    echo "Missing value for parameter: '$1'" >&2
    exit 1
  fi
}

#######################################
# Parse command-line overrides. Every option is accepted both as
# "--name value" and as "--name=value".
# Globals:   ACTOR_MODEL_NAME_OR_PATH, REWARD_MODEL_NAME_OR_PATH,
#            REWARD_CRITIC_MODEL_NAME_OR_PATH, COST_MODEL_NAME_OR_PATH,
#            COST_CRITIC_MODEL_NAME_OR_PATH, OUTPUT_DIR, HOSTFILE, ZERO_STAGE,
#            OFFLOAD (written when the matching option is given)
# Arguments: the command-line words to parse
# Outputs:   error message on stderr for an unknown option or a missing value
# Returns:   0 on success; exits the script with 1 on error
#######################################
parse_args() {
  local arg
  while [[ "$#" -gt 0 ]]; do
    arg="$1"
    shift
    case "${arg}" in
      --actor_model_name_or_path)
        require_value "${arg}" "$#"
        ACTOR_MODEL_NAME_OR_PATH="$1"
        shift
        ;;
      --actor_model_name_or_path=*)
        ACTOR_MODEL_NAME_OR_PATH="${arg#*=}"
        ;;
      --reward_model_name_or_path)
        require_value "${arg}" "$#"
        REWARD_MODEL_NAME_OR_PATH="$1"
        shift
        ;;
      --reward_model_name_or_path=*)
        REWARD_MODEL_NAME_OR_PATH="${arg#*=}"
        ;;
      --reward_critic_model_name_or_path)
        require_value "${arg}" "$#"
        REWARD_CRITIC_MODEL_NAME_OR_PATH="$1"
        shift
        ;;
      --reward_critic_model_name_or_path=*)
        REWARD_CRITIC_MODEL_NAME_OR_PATH="${arg#*=}"
        ;;
      --cost_model_name_or_path)
        require_value "${arg}" "$#"
        COST_MODEL_NAME_OR_PATH="$1"
        shift
        ;;
      --cost_model_name_or_path=*)
        COST_MODEL_NAME_OR_PATH="${arg#*=}"
        ;;
      --cost_critic_model_name_or_path)
        require_value "${arg}" "$#"
        COST_CRITIC_MODEL_NAME_OR_PATH="$1"
        shift
        ;;
      --cost_critic_model_name_or_path=*)
        COST_CRITIC_MODEL_NAME_OR_PATH="${arg#*=}"
        ;;
      --output_dir)
        require_value "${arg}" "$#"
        OUTPUT_DIR="$1"
        shift
        ;;
      --output_dir=*)
        OUTPUT_DIR="${arg#*=}"
        ;;
      --hostfile)
        require_value "${arg}" "$#"
        HOSTFILE="$1"
        shift
        ;;
      --hostfile=*)
        HOSTFILE="${arg#*=}"
        ;;
      --zero_stage)
        require_value "${arg}" "$#"
        ZERO_STAGE="$1"
        shift
        ;;
      --zero_stage=*)
        ZERO_STAGE="${arg#*=}"
        ;;
      --offload)
        require_value "${arg}" "$#"
        OFFLOAD="$1"
        shift
        ;;
      --offload=*)
        OFFLOAD="${arg#*=}"
        ;;
      *)
        echo "Unknown parameter passed: '${arg}'" >&2
        exit 1
        ;;
    esac
  done
}

parse_args "$@"

# When a critic path was never set, start the critic from the corresponding
# reward/cost model. "${VAR+x}" is empty only when VAR is unset, so a critic
# path that was explicitly given as an empty string is left alone.
if [[ -z "${REWARD_CRITIC_MODEL_NAME_OR_PATH+x}" ]]; then
  REWARD_CRITIC_MODEL_NAME_OR_PATH="${REWARD_MODEL_NAME_OR_PATH}"
fi
if [[ -z "${COST_CRITIC_MODEL_NAME_OR_PATH+x}" ]]; then
  COST_CRITIC_MODEL_NAME_OR_PATH="${COST_MODEL_NAME_OR_PATH}"
fi

# Create the output directory and turn OUTPUT_DIR into an absolute path.
# Both steps stop the script on failure: if OUTPUT_DIR ended up empty, the
# files written below would land directly under "/".
if ! mkdir -p -- "${OUTPUT_DIR}"; then
  echo "Failed to create output directory: '${OUTPUT_DIR}'" >&2
  exit 1
fi
if ! resolved_output_dir="$(cd -- "${OUTPUT_DIR}" &>/dev/null && pwd)"; then
  echo "Failed to enter output directory: '${OUTPUT_DIR}'" >&2
  exit 1
fi
OUTPUT_DIR="${resolved_output_dir}"
unset resolved_output_dir

# Seed a catch-all .gitignore in the output directory unless one already exists.
if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
  echo '*' >"${OUTPUT_DIR}/.gitignore"
fi

# Keep a copy of this script next to the outputs it produced.
cp -f -- "$0" "${OUTPUT_DIR}/script.sh"

# Without a non-empty WANDB_API_KEY, switch wandb to offline mode.
# "${WANDB_API_KEY:+set}" expands to the fixed word "set" when the key is
# non-empty, so the `set -x` trace never prints the key itself.
if [[ -z "${WANDB_API_KEY:+set}" ]]; then
  export WANDB_MODE="offline"
fi

# Pick a random port in [MASTER_PORT_START, MASTER_PORT_END] that does not
# appear as a port in the fourth column of `ss -Htan` (presumably the local
# address:port column — confirm against the installed ss).
#   * comm -23 keeps the lines found only in the first list (the candidates);
#     both lists go through `sort` because comm expects sorted input.
#   * awk prints whatever follows the last ":" of the fourth column.
MASTER_PORT_START=10000
MASTER_PORT_END=65535
MASTER_PORT="$(
  comm -23 \
    <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
    <(ss -Htan | awk '{ n = split($4, parts, ":"); print parts[n] }' | sort -u) |
    shuf -n 1
)"
# An empty result means the pipeline above produced nothing (for example a
# required tool is missing); stop instead of passing an empty port onwards.
if [[ -z "${MASTER_PORT}" ]]; then
  echo "Failed to find a free port in ${MASTER_PORT_START}-${MASTER_PORT_END}." >&2
  exit 1
fi

# Launcher options for the `deepspeed` command, kept in an array so that values
# containing spaces stay single arguments.
DEEPSPEED_ARGS=()
# "${HOSTFILE+x}" is non-empty whenever HOSTFILE is set, even to an empty string.
if [[ -n "${HOSTFILE+x}" ]]; then
  DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
fi
DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")

# From here on, copy stdout and stderr into log files under OUTPUT_DIR while
# still emitting them on the script's own stdout and stderr.
exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)

# Launch the safe_rlhf.algorithms.ppo_lag module through the deepspeed launcher.
# Launcher options come from DEEPSPEED_ARGS, which already carries
# --master_port, so that flag is not repeated here. Everything after --module
# is passed to the training module: the model paths, output directory, ZeRO
# stage and offload mode come from the variables set above, the remaining
# values are fixed.
deepspeed "${DEEPSPEED_ARGS[@]}" \
  --module safe_rlhf.algorithms.ppo_lag \
  --train_datasets PKU-SafeRLHF/train \
  --ptx_datasets alpaca \
  --actor_model_name_or_path "${ACTOR_MODEL_NAME_OR_PATH}" \
  --reward_model_name_or_path "${REWARD_MODEL_NAME_OR_PATH}" \
  --reward_critic_model_name_or_path "${REWARD_CRITIC_MODEL_NAME_OR_PATH}" \
  --cost_model_name_or_path "${COST_MODEL_NAME_OR_PATH}" \
  --cost_critic_model_name_or_path "${COST_CRITIC_MODEL_NAME_OR_PATH}" \
  --max_length 512 \
  --temperature 1.0 \
  --num_return_sequences 1 \
  --repetition_penalty 1.0 \
  --trust_remote_code True \
  --epochs 1 \
  --update_iters 1 \
  --per_device_prompt_batch_size 16 \
  --per_device_train_batch_size 16 \
  --gradient_accumulation_steps 1 \
  --actor_lr 1e-5 \
  --actor_weight_decay 0.01 \
  --actor_lr_scheduler_type cosine \
  --actor_lr_warmup_ratio 0.03 \
  --actor_gradient_checkpointing \
  --critic_lr 5e-6 \
  --critic_weight_decay 0.0 \
  --critic_lr_scheduler_type constant \
  --critic_lr_warmup_ratio 0.03 \
  --critic_gradient_checkpointing \
  --normalize_reward False \
  --normalize_cost False \
  --seed 42 \
  --threshold 0.0 \
  --lambda_init 1.0 \
  --lambda_lr 0.1 \
  --lambda_max 5.0 \
  --lambda_update_delay_steps 0 \
  --episode_cost_window_size 128 \
  --kl_coeff 0.01 \
  --clip_range_ratio 0.2 \
  --clip_range_score 50.0 \
  --clip_range_value 5.0 \
  --ptx_coeff 16.0 \
  --output_dir "${OUTPUT_DIR}" \
  --log_type wandb \
  --log_project Safe-RLHF-PPO \
  --zero_stage "${ZERO_STAGE}" \
  --offload "${OFFLOAD}" \
  --fp16 True