| #!/usr/bin/env bash |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| set -euo pipefail |
|
|
| |
| RUN_NAME="${RUN_NAME:-korean_3b_sft_v2}" |
| CONFIG="${CONFIG:-configs/korean_3b_sft_v2.yaml}" |
| BASE_CHECKPOINT="${BASE_CHECKPOINT:-checkpoints/korean_3b_fp8_run1/checkpoint-0057000}" |
| SFT_DATA="${SFT_DATA:-data/sft_combined/train_filtered.jsonl}" |
| VAL_DATA="${VAL_DATA:-data/sft_combined/val_filtered.jsonl}" |
| PRETRAIN_DATA="${PRETRAIN_DATA:-data/3b_train.bin}" |
| CKPT_DIR="checkpoints/${RUN_NAME}" |
| LOG_FILE="${CKPT_DIR}/train.log" |
| NPROC=8 |
| MASTER_PORT="${MASTER_PORT:-29504}" |
|
|
| MAX_STEPS=15000 |
| BATCH_SIZE=4 |
| GRAD_ACCUM=8 |
| LR="5.0e-5" |
| WARMUP_STEPS=2000 |
| WEIGHT_DECAY=0.05 |
| PRETRAIN_MIX_RATIO=0.3 |
| SEED=42 |
|
|
| EXTRA_ARGS="$@" |
|
|
| |
| export NCCL_IB_DISABLE=1 |
| export NCCL_ALGO=Ring |
| export NCCL_PROTO=Simple |
| export NCCL_MIN_NCHANNELS=16 |
| export NCCL_MAX_NCHANNELS=16 |
| export NCCL_BUFFSIZE=67108864 |
| export OMP_NUM_THREADS=4 |
| export MKL_NUM_THREADS=4 |
|
|
| |
| export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True |
|
|
| cd "$(dirname "$0")/.." |
|
|
| |
| if [[ ! -d "${BASE_CHECKPOINT}" ]]; then |
| echo "==================================================================" |
| echo " ERROR: Base checkpoint not found: ${BASE_CHECKPOINT}" |
| echo " Set BASE_CHECKPOINT env var or use --base_checkpoint CLI arg." |
| echo "==================================================================" |
| exit 1 |
| fi |
|
|
| if [[ ! -f "${SFT_DATA}" ]]; then |
| echo "==================================================================" |
| echo " ERROR: SFT data not found: ${SFT_DATA}" |
| echo " Run: bash scripts/prepare_sft_combined.sh" |
| echo "==================================================================" |
| exit 1 |
| fi |
|
|
| if [[ ! -f "${PRETRAIN_DATA}" ]]; then |
| echo "==================================================================" |
| echo " ERROR: Pretrain data not found: ${PRETRAIN_DATA}" |
| echo " Set PRETRAIN_DATA env var to the correct path." |
| echo "==================================================================" |
| exit 1 |
| fi |
|
|
| |
| if [[ ! -f "${VAL_DATA}" ]]; then |
| VAL_FALLBACK="data/sft_combined/val.jsonl" |
| if [[ -f "${VAL_FALLBACK}" ]]; then |
| VAL_DATA="${VAL_FALLBACK}" |
| echo "[INFO] val_filtered not found, fallback: ${VAL_DATA}" |
| else |
| echo "ERROR: VAL_DATA not found: ${VAL_DATA}" |
| exit 1 |
| fi |
| fi |
|
|
| mkdir -p "${CKPT_DIR}" |
|
|
| echo "==================================================================" |
| echo " 3B SFT v2 Fine-Tuning" |
| echo " Run name : ${RUN_NAME}" |
| echo " Config : ${CONFIG}" |
| echo " Base checkpoint : ${BASE_CHECKPOINT}" |
| echo " SFT data : ${SFT_DATA}" |
| echo " Pretrain data : ${PRETRAIN_DATA}" |
| echo " Val data : ${VAL_DATA}" |
| echo " CKPT dir : ${CKPT_DIR}" |
| echo " Log file : ${LOG_FILE}" |
| echo " Max steps : ${MAX_STEPS}" |
| echo " Batch size : ${BATCH_SIZE} (local) x ${NPROC} GPU x ${GRAD_ACCUM} grad_accum = $((BATCH_SIZE * NPROC * GRAD_ACCUM)) eff_batch" |
| echo " Learning rate : ${LR}" |
| echo " Weight decay : ${WEIGHT_DECAY}" |
| echo " Warmup : ${WARMUP_STEPS} steps" |
| echo " Data mixing : $((100 - ${PRETRAIN_MIX_RATIO%.*}0))% SFT + ${PRETRAIN_MIX_RATIO}00% pretrain" |
| echo " Master port : ${MASTER_PORT}" |
| echo " ALLOC_CONF : ${PYTORCH_CUDA_ALLOC_CONF}" |
| echo " Started : $(date)" |
| echo "==================================================================" |
|
|
| export PYTHONWARNINGS="ignore::UserWarning:torch.library" |
|
|
| torchrun \ |
| --nproc_per_node=${NPROC} \ |
| --master_port=${MASTER_PORT} \ |
| train/sft.py \ |
| --config "${CONFIG}" \ |
| --base_checkpoint "${BASE_CHECKPOINT}" \ |
| --sft_data "${SFT_DATA}" \ |
| --val_data "${VAL_DATA}" \ |
| --pretrain_data "${PRETRAIN_DATA}" \ |
| --pretrain_mix_ratio ${PRETRAIN_MIX_RATIO} \ |
| --checkpoint_dir "${CKPT_DIR}" \ |
| --log_file "${LOG_FILE}" \ |
| --max_steps ${MAX_STEPS} \ |
| --batch_size ${BATCH_SIZE} \ |
| --grad_accum ${GRAD_ACCUM} \ |
| --lr ${LR} \ |
| --weight_decay ${WEIGHT_DECAY} \ |
| --warmup_steps ${WARMUP_STEPS} \ |
| --seed ${SEED} \ |
| --use_fp8 \ |
| ${EXTRA_ARGS} \ |
| 2>&1 | grep -v "UserWarning" \ |
| | grep -v "Warning only once" \ |
| | grep -v "Overriding a previously" \ |
| | grep -v "dispatch key:" \ |
| | grep -v "previous kernel:" \ |
| | grep -v "new kernel:" \ |
| | grep -v "operator: flash_attn" \ |
| | grep -v "registered at /usr/local" \ |
| | grep -v "self.m.impl" \ |
| | tee -a "${LOG_FILE}" |
|
|
| echo "==================================================================" |
| echo " 3B SFT v2 Done : $(date)" |
| echo "==================================================================" |
|
|