# gemma-3-1b-it-Math-GRPO / train_sft.py
# Uploaded by NotoriousH2 ("Add train_sft.py", commit d304eb5, verified)
"""C17d: all solutions + length filter (keep <= 1500 chars only) + NaN prevention."""
import json, re, random, torch, numpy as np, os
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from transformers import EarlyStoppingCallback
from datasets import Dataset
# --- Reproducibility -------------------------------------------------------
# Seed every RNG the run touches (python, numpy, torch CPU, torch CUDA).
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # BUG FIX: get_device_capability() raises RuntimeError on CPU-only
    # machines, so the check must live inside the cuda-available guard.
    # Compute capability >= 8 (Ampere+) supports TF32 matmuls.
    if torch.cuda.get_device_capability()[0] >= 8:
        torch.set_float32_matmul_precision('high')
# System prompt: solve step by step, final answer as \boxed{int} on the last line.
SP = "μ£Όμ–΄μ§„ μˆ˜ν•™ 문제λ₯Ό λ‹¨κ³„λ³„λ‘œ ν’€κ³  닡변을 μž‘μ„±ν•˜μ„Έμš”.\nλ°˜λ“œμ‹œ μ΅œμ’… 닡변을 \\boxed{μ •μˆ˜} ν˜•μ‹μœΌλ‘œ λ§ˆμ§€λ§‰ 쀄에 좜λ ₯ν•˜μ„Έμš”.\nμ˜ˆμ‹œ: \\boxed{42}"

print("=== C17d: All solutions, length-filtered (≀1500 chars) ===")

# Load the raw records; keep only those whose answer is at most 1500 chars.
with open("data/GSM8K_full_qwen3_30b.json") as fh:
    raw = json.load(fh)
kept = [rec for rec in raw if len(rec['answer']) <= 1500]
print(f"원본: {len(raw)}개 β†’ ν•„ν„° ν›„: {len(kept)}개 (제거: {len(raw)-len(kept)})")

# Shuffle (seeded above), report question duplication, then 95/5 train/test split.
random.shuffle(kept)
n_unique = len({rec["question"] for rec in kept})
print(f"Unique: {n_unique}, avg {len(kept)/n_unique:.1f}/q")
cut = int(len(kept) * 0.95)
train, test = kept[:cut], kept[cut:]
def to_sft(ex):
    """Convert one raw record into the prompt/completion chat format SFTTrainer expects."""
    user_msg = {"role": "user", "content": f"{SP}\n\n{ex['question']}"}
    assistant_msg = {"role": "assistant", "content": ex["answer"]}
    return {"prompt": [user_msg], "completion": [assistant_msg]}
# After mapping, drop every original column so only prompt/completion remain.
_keep = {"prompt", "completion"}
drop_cols = [c for c in Dataset.from_list(train[:1]).column_names if c not in _keep]
train_ds = Dataset.from_list(train).map(to_sft, remove_columns=drop_cols)
test_ds = Dataset.from_list(test).map(to_sft, remove_columns=drop_cols)
print(f"ν•™μŠ΅: {len(train_ds)} / 검증: {len(test_ds)}")
# Load the local base model and its tokenizer.
BASE_MODEL_DIR = "outputs/models/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as padding token
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_DIR,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation='flash_attention_2',
)
# Gradient checkpointing trades compute for memory; the KV cache must be off
# while it is enabled during training.
model.gradient_checkpointing_enable()
model.config.use_cache = False
cfg = SFTConfig(
    # Output / logging
    output_dir="outputs/c17d_checkpoints",
    report_to='none',
    logging_steps=50,
    seed=SEED,
    # Evaluate and checkpoint every 200 steps; keep 2 checkpoints and
    # restore the one with the best eval_loss when training ends.
    eval_strategy="steps", eval_steps=200,
    save_strategy="steps", save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True, metric_for_best_model="eval_loss",
    # Optimization schedule
    num_train_epochs=3,
    learning_rate=2e-5, lr_scheduler_type='cosine', warmup_ratio=0.05,
    weight_decay=0.01, max_grad_norm=1.0,
    optim="paged_adamw_8bit", bf16=True,
    neftune_noise_alpha=5,
    # Batching: effective train batch = 8 * 4 accumulation steps
    per_device_train_batch_size=8, gradient_accumulation_steps=4,
    per_device_eval_batch_size=2,
    max_length=2048,
)
# Stop early after 3 consecutive evals without eval_loss improvement.
stopper = EarlyStoppingCallback(early_stopping_patience=3)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=cfg,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    callbacks=[stopper],
)
print("ν•™μŠ΅ μ‹œμž‘ (3 epochs, λͺ¨λ“  풀이, ≀1500자)")
result = trainer.train()
print(f"μ™„λ£Œ! Loss: {result.training_loss:.4f}")

# Persist weights + tokenizer. safe_serialization=False writes legacy torch
# .bin files (presumably for downstream-tooling compatibility — confirm).
SAVE = "outputs/models/c17d-gemma-3-1b-it-Math"
os.makedirs(SAVE, exist_ok=True)
model.eval()
model.save_pretrained(SAVE, safe_serialization=False)
tokenizer.save_pretrained(SAVE)
print(f"μ €μž₯: {SAVE}")

# Drop references and release cached GPU memory.
del model, trainer
torch.cuda.empty_cache()
print("GPU ν•΄μ œ")