course documentation
TRL တွင် GRPO ကို အကောင်အထည်ဖော်ခြင်း
TRL တွင် GRPO ကို အကောင်အထည်ဖော်ခြင်း
ဒီစာမျက်နှာမှာ၊ Transformer Reinforcement Learning (TRL) library ကို အသုံးပြုပြီး Group Relative Policy Optimization (GRPO) ကို ဘယ်လိုအကောင်အထည်ဖော်ရမလဲဆိုတာ လေ့လာသွားမှာပါ။ ကျွန်တော်တို့ဟာ code ကို အနည်းဆုံးနဲ့ လက်တွေ့အကောင်အထည်ဖော်ခြင်းကို အဓိကထားမှာပါ။
GRPO ရဲ့ အဓိကသဘောတရားတွေကို TRL ရဲ့ GRPOTrainer မှာ ဘယ်လိုပါဝင်နေလဲဆိုတာကို လေ့လာသွားမှာဖြစ်ပြီး၊ တရားဝင် TRL documentation က snippets တွေကို လမ်းညွှန်အဖြစ် အသုံးပြုပါမယ်။
ဒီအခန်းက TRL စတင်လေ့လာသူတွေအတွက် ရည်ရွယ်ပါတယ်။ သင် TRL ကို ကျွမ်းကျင်ပြီးသားဆိုရင်၊ GRPO ရဲ့ Open R1 implementation ကိုလည်း လေ့လာကြည့်နိုင်ပါတယ်။
ပထမဆုံးအနေနဲ့၊ GRPO algorithm ရဲ့ အရေးကြီးတဲ့ သဘောတရားအချို့ကို ပြန်လည်သတိရကြရအောင်။
- Group Formation: model က prompt တစ်ခုစီအတွက် completions များစွာကို ထုတ်လုပ်ပါတယ်။
- Preference Learning: model က completions အုပ်စုတွေကို နှိုင်းယှဉ်တဲ့ reward function ကနေ သင်ယူပါတယ်။
- Training Configuration: model က training process ကို ထိန်းချုပ်ဖို့ configuration တစ်ခုကို အသုံးပြုပါတယ်။
GRPO ကို အကောင်အထည်ဖော်ဖို့ ကျွန်တော်တို့ ဘာတွေလုပ်ဖို့ လိုအပ်မလဲ။
- prompts များ၏ dataset တစ်ခုကို သတ်မှတ်ပါ။
- completions စာရင်းကို ယူပြီး rewards စာရင်းကို ပြန်ပေးမယ့် reward function တစ်ခုကို သတ်မှတ်ပါ။
- training process ကို GRPOConfig တစ်ခုဖြင့် configure လုပ်ပါ။
- GRPOTrainer ကို အသုံးပြုပြီး model ကို train လုပ်ပါ။
GRPO training ကို စတင်ဖို့အတွက် အနိမ့်ဆုံး ဥပမာတစ်ခုကတော့ အောက်ပါအတိုင်းပါ။
from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset
# 1. သင့် dataset ကို load လုပ်ပါ
dataset = load_dataset("your_dataset", split="train")
# 2. ရိုးရှင်းသော reward function တစ်ခုကို သတ်မှတ်ပါ
def reward_func(completions, **kwargs):
"""ဥပမာ- ပိုရှည်သော completions များကို ဆုချပါ"""
return [float(len(completion)) for completion in completions]
# 3. Training ကို Configure လုပ်ပါ
training_args = GRPOConfig(
output_dir="output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
logging_steps=10,
)
# 4. စတင်ပြီး train လုပ်ပါ
trainer = GRPOTrainer(
model="your_model", # ဥပမာ- "Qwen/Qwen2-0.5B-Instruct"
args=training_args,
train_dataset=dataset,
reward_funcs=reward_func,
)
trainer.train()အဓိက အစိတ်အပိုင်းများ
၁။ Dataset Format
သင့် dataset တွင် model က တုံ့ပြန်မည့် prompts များ ပါဝင်သင့်ပါတယ်။ GRPO trainer က prompt တစ်ခုစီအတွက် completions များစွာကို ထုတ်လုပ်ပြီး ၎င်းတို့ကို နှိုင်းယှဉ်ဖို့ reward function ကို အသုံးပြုပါလိမ့်မယ်။
၂။ Reward Function
reward function ဟာ အရေးကြီးပါတယ်။ ဒါက model က ဘယ်လိုသင်ယူတယ်ဆိုတာကို ဆုံးဖြတ်ပါတယ်။ လက်တွေ့ဥပမာ နှစ်ခုကတော့ အောက်ပါအတိုင်းပါ။
# ဥပမာ ၁- completion အရှည်ပေါ်မူတည်သော reward
def reward_length(completions, **kwargs):
return [float(len(completion)) for completion in completions]
# ဥပမာ ၂- pattern ကို ကိုက်ညီမှုအပေါ်မူတည်သော reward
import re
def reward_format(completions, **kwargs):
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
return [1.0 if re.match(pattern, c) else 0.0 for c in completions]၃။ Training Configuration
GRPOConfig တွင် ထည့်သွင်းစဉ်းစားရမည့် အဓိက parameters များ -
training_args = GRPOConfig(
# မရှိမဖြစ်လိုအပ်သော parameters များ
output_dir="output",
num_train_epochs=3,
num_generation=4, # prompt တစ်ခုစီအတွက် ထုတ်လုပ်မည့် completions အရေအတွက်
per_device_train_batch_size=4, # generations အားလုံးကို device batch တစ်ခုတည်းမှာ ရယူလိုပါသည်။
# ရွေးချယ်နိုင်သော သို့သော် အသုံးဝင်သော
gradient_accumulation_steps=2,
learning_rate=1e-5,
logging_steps=10,
# GRPO အတွက် သီးခြား (ရွေးချယ်နိုင်သော)
use_vllm=True, # generation ကို အရှိန်မြှင့်ရန်
)num_generation parameter က GRPO အတွက် အထူးအရေးကြီးပါတယ်။ ဒါက group size ကို သတ်မှတ်ပါတယ်။ ဆိုလိုတာက model က prompt တစ်ခုစီအတွက် မတူညီတဲ့ completions ဘယ်နှစ်ခု ထုတ်လုပ်မလဲဆိုတာပါပဲ။ ဒါက အခြား RL methods တွေနဲ့ ကွာခြားတဲ့ အဓိကအချက်ပါ။
- အလွန်နည်းပါးလွန်းခြင်း (ဥပမာ- ၂-၃ ခု)- အဓိပ္ပာယ်ရှိသော နှိုင်းယှဉ်မှုများအတွက် လုံလောက်သော ကွဲပြားမှု (diversity) ကို မပေးနိုင်ပါ။
- အကြံပြုထားသော (၄-၁၆ ခု)- ကွဲပြားမှုနှင့် တွက်ချက်မှု ထိရောက်မှု (computational efficiency) အကြား ကောင်းမွန်သော ဟန်ချက်ကို ပေးပါသည်။
- ပိုမိုကြီးမားသော တန်ဖိုးများ- သင်ယူမှုကို ပိုမိုကောင်းမွန်စေနိုင်သော်လည်း တွက်ချက်မှု ကုန်ကျစရိတ်ကို သိသိသာသာ တိုးမြှင့်စေသည်။
group size ကို သင့်ရဲ့ computational resources တွေနဲ့ task ရဲ့ ရှုပ်ထွေးမှုပေါ် မူတည်ပြီး ရွေးချယ်သင့်ပါတယ်။ ရိုးရှင်းတဲ့ tasks တွေအတွက်၊ သေးငယ်တဲ့ groups တွေ (၄-၈) က လုံလောက်နိုင်ပြီး၊ ပိုမိုရှုပ်ထွေးတဲ့ reasoning tasks တွေကတော့ ပိုမိုကြီးမားတဲ့ groups တွေ (၈-၁၆) ကနေ အကျိုးအမြတ်ရနိုင်ပါတယ်။
အောင်မြင်မှုအတွက် အကြံပြုချက်များ
၁။ Memory Management: သင်၏ GPU memory အပေါ်မူတည်၍ per_device_train_batch_size နှင့် gradient_accumulation_steps ကို ချိန်ညှိပါ။
၂။ Speed: သင်၏ model ကို ထောက်ပံ့ပါက ပိုမိုမြန်ဆန်သော generation အတွက် use_vllm=True ကို ဖွင့်ပါ။
၃။ Monitoring: training လုပ်နေစဉ်အတွင်း log လုပ်ထားသော metrics များကို စောင့်ကြည့်ပါ။
reward: completions များ၏ ပျမ်းမျှ reward။reward_std: reward groups များအတွင်းရှိ standard deviation။kl: reference model မှ KL divergence။
Reward Function ဒီဇိုင်း
DeepSeek R1 paper က သင်၏ GRPO implementation အတွက် လိုက်လျောညီထွေဖြစ်အောင် လုပ်ဆောင်နိုင်သော reward function ဒီဇိုင်းချခြင်း နည်းလမ်းများစွာကို ပြသထားပါတယ်။
၁။ Length-Based Rewards
အကောင်အထည်ဖော်ရအလွယ်ဆုံး rewards တွေထဲက တစ်ခုကတော့ length-based reward ပါပဲ။ ပိုရှည်တဲ့ completions တွေကို ဆုချနိုင်ပါတယ်။
def reward_len(completions, **kwargs):
ideal_length = 20
return [-abs(ideal_length - len(completion)) for completion in completions]ဒီ reward function က အလွန်တိုတောင်းလွန်းတဲ့ ဒါမှမဟုတ် အလွန်ရှည်လျားလွန်းတဲ့ completions တွေကို အပြစ်ပေးပါတယ်။ ဒါက model ကို ideal length ၂၀ tokens နဲ့ နီးစပ်တဲ့ completions တွေ ထုတ်လုပ်ဖို့ တိုက်တွန်းပါတယ်။
၂။ Verifiable Tasks အတွက် Rule-Based Rewards
သင်္ချာ သို့မဟုတ် coding ကဲ့သို့ တိကျမှန်ကန်သော အဖြေများရှိသည့် tasks များအတွက်၊ rule-based reward functions များကို အကောင်အထည်ဖော်နိုင်ပါတယ်။
def problem_reward(completions, answers, **kwargs):
"""Verifiable အဖြေများပါသော သင်္ချာပြဿနာများအတွက် reward function
completions: အကဲဖြတ်ရန် completions စာရင်း
answers: dataset မှ ပြဿနာများအတွက် အဖြေများစာရင်း
"""
rewards = []
for completion, correct_answer in zip(completions, answers):
# completion မှ အဖြေကို ထုတ်ယူပါ
try:
# ဒါက ရိုးရှင်းတဲ့ ဥပမာတစ်ခုပါ - သင့်လျော်တဲ့ parsing လိုအပ်ပါလိမ့်မယ်
answer = extract_final_answer(completion)
# Binary reward: မှန်ရင် 1၊ မှားရင် 0
reward = 1.0 if answer == correct_answer else 0.0
rewards.append(reward)
except:
# အဖြေကို parse မလုပ်နိုင်ရင်၊ reward နည်းနည်း ပေးပါ
rewards.append(0.0)
return rewards၃။ Format-Based Rewards
DeepSeek R1 training မှာ အရေးကြီးခဲ့တဲ့ သင့်လျော်တဲ့ formatting ကိုလည်း ဆုချနိုင်ပါတယ်။
def format_reward(completions, **kwargs):
"""လိုချင်သော format ကို လိုက်နာသော completions များကို ဆုချပါ"""
# ဥပမာ- completion က think-then-answer format ကို လိုက်နာခြင်းရှိမရှိ စစ်ဆေးပါ
pattern = r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"
rewards = []
for completion in completions:
match = re.search(pattern, completion, re.DOTALL)
if match:
# sections နှစ်ခုလုံးမှာ အဓိက အကြောင်းအရာများ ရှိမရှိ စစ်ဆေးပါ
think_content = match.group(1).strip()
answer_content = match.group(2).strip()
if len(think_content) > 20 and len(answer_content) > 0:
rewards.append(1.0)
else:
rewards.append(
0.5
) # မှန်ကန်သော format ဖြစ်သော်လည်း အကြောင်းအရာ နည်းပါးပါက partial reward
else:
rewards.append(0.0) # format မမှန်က reward မပေးပါ
return rewardsဒီဥပမာတွေက DeepSeek R1 training process ကနေ လှုံ့ဆော်မှုရယူပြီး မှန်ကန်မှု၊ formatting နဲ့ ပေါင်းစပ်ထားသော signals တွေကို အဓိကထားတဲ့ reward functions တွေကို ဘယ်လိုအကောင်အထည်ဖော်ရမယ်ဆိုတာကို ပြသထားပါတယ်။
ဒါပါပဲ!
နောက်အပိုင်းမှာ၊ TRL မှာ GRPO ကို အကောင်အထည်ဖော်ဖို့ လေ့ကျင့်ခန်းတစ်ခုကို သင်လိုက်လုပ်ရပါလိမ့်မယ်။
ဝေါဟာရ ရှင်းလင်းချက် (Glossary)
- GRPO (Group Relative Policy Optimization): Reinforcement Learning (RL) algorithm တစ်ခုဖြစ်ပြီး model က ထုတ်လုပ်လိုက်တဲ့ completions အုပ်စုတွေကို နှိုင်းယှဉ်ပြီး သင်ယူကာ model ရဲ့ policy ကို optimize လုပ်ပါတယ်။
- TRL (Transformer Reinforcement Learning) Library: Hugging Face မှ ထုတ်လုပ်ထားသော library တစ်ခုဖြစ်ပြီး Transformer models များကို Reinforcement Learning techniques ဖြင့် fine-tune လုပ်ရန် ရည်ရွယ်သည်။
- Implementation: သီအိုရီ သို့မဟုတ် algorithm တစ်ခုကို code အဖြစ် အကောင်အထည်ဖော်ခြင်း။
- GRPOTrainer: TRL library မှ GRPO algorithm ကို အကောင်အထည်ဖော်သော Trainer class။
- TRL Documentation: TRL library ၏ တရားဝင်မှတ်တမ်းများ။
- Open R1 Implementation: GRPO algorithm ၏ open-source အကောင်အထည်ဖော်မှု။
- Group Formation: model က prompt တစ်ခုစီအတွက် completions များစွာကို ထုတ်လုပ်ပြီး အုပ်စုဖွဲ့ခြင်း။
- Completions: model က prompt တစ်ခုကို တုံ့ပြန်တဲ့အနေနဲ့ ထုတ်လုပ်ပေးတဲ့ စာသား သို့မဟုတ် sequence များ။
- Preference Learning: reward function မှတဆင့် completions အုပ်စုများကို နှိုင်းယှဉ်ခြင်းဖြင့် model က သင်ယူသော လုပ်ငန်းစဉ်။
- Reward Function: model ၏ output (completions) များကို အကဲဖြတ်ပြီး ဂဏန်းတန်ဖိုး (reward) တစ်ခုကို ပြန်ပေးသော function။ ၎င်းသည် model ကို သင်ယူရာတွင် လမ်းညွှန်ပေးသည်။
- Training Configuration: training process အတွက် parameters များနှင့် settings များကို သတ်မှတ်ခြင်း။
- GRPOConfig: TRL library မှ GRPO training အတွက် configuration များကို ထိန်းချုပ်သော class။
- Dataset of Prompts: model က တုံ့ပြန်ရန်အတွက် အသုံးပြုမည့် prompts များပါဝင်သော dataset။
trl: Transformer Reinforcement Learning library။GRPOTrainer: TRL မှ GRPO algorithm အတွက် Trainer class။GRPOConfig: TRL မှ GRPO training အတွက် configuration class။load_dataset: Hugging Face Datasets library မှ dataset များကို load လုပ်ရန် function။output_dir: trained model နှင့် logs များကို သိမ်းဆည်းမည့် directory။num_train_epochs: training လုပ်မည့် epochs အရေအတွက်။per_device_train_batch_size: device တစ်ခုစီ (ဥပမာ- GPU) အတွက် batch size။gradient_accumulation_steps: gradients များကို update မလုပ်မီ batches မည်မျှစုဆောင်းမည်ကို သတ်မှတ်ခြင်း။logging_steps: training log များကို မည်သည့် step အရေအတွက်တိုင်းတွင် မှတ်တမ်းတင်မည်ကို သတ်မှတ်ခြင်း။model(argument inGRPOTrainer): အသုံးပြုမည့် base model ၏ identifier သို့မဟုတ် instance။args(argument inGRPOTrainer): training configuration arguments များ။train_dataset: training အတွက် အသုံးပြုမည့် dataset။reward_funcs: reward function (များ)။trainer.train(): training process ကို စတင်ရန် method။- Prompts: model ကို တုံ့ပြန်စေလိုသော စာသား input များ။
reModule: Python ၏ regular expression module။re.match(): string ၏ အစမှ pattern ကို ကိုက်ညီမှုရှိမရှိ စစ်ဆေးရန် function။num_generation: prompt တစ်ခုစီအတွက် model က ထုတ်လုပ်မည့် completions အရေအတွက်။ ၎င်းသည် GRPO ၏ group size ဖြစ်သည်။- RL Methods (Reinforcement Learning Methods): trial-and-error မှတစ်ဆင့် သင်ယူပြီး reward အများဆုံးရရှိရန် ကြိုးစားသော Machine Learning algorithms များ။
- Diversity: ထုတ်လုပ်လိုက်သော completions များ၏ ကွဲပြားမှု။
- Computational Efficiency: တွက်ချက်မှုအရင်းအမြစ်များကို မည်မျှထိရောက်စွာ အသုံးပြုသည်ကို ဆိုလိုသည်။
- Computational Cost: တွက်ချက်မှု လုပ်ဆောင်ရန် လိုအပ်သော အချိန်နှင့် အရင်းအမြစ်များ။
- Reasoning Tasks: အကြောင်းပြချက်၊ ဆင်ခြင်တုံတရား လိုအပ်သော လုပ်ငန်းများ။
- Memory Management: ကွန်ပျူတာ၏ မှတ်ဉာဏ် (memory) အသုံးပြုမှုကို ထိန်းချုပ်ခြင်း။
- GPU Memory: Graphics Processing Unit (GPU) တွင်ရှိသော မှတ်ဉာဏ်။
use_vllm=True: vLLM (a high-throughput inference engine) ကို အသုံးပြု၍ generation ကို အရှိန်မြှင့်ရန်။- Logged Metrics: training လုပ်နေစဉ်အတွင်း မှတ်တမ်းတင်ထားသော တိုင်းတာမှုများ။
reward(metric): completions များ၏ ပျမ်းမျှ reward တန်ဖိုး။reward_std(metric): reward groups များအတွင်းရှိ rewards များ၏ standard deviation။kl(metric): KL divergence (Kullback-Leibler divergence) ကို ရည်ညွှန်းပြီး reference model မှ policy က မည်မျှကွာခြားသည်ကို တိုင်းတာသည်။- DeepSeek R1 Paper: DeepSeek R1 model နှင့် ၎င်း၏ training method များကို ဖော်ပြထားသော research paper။
- Length-Based Reward: completion ၏ အရှည်ပေါ်မူတည်၍ ပေးသော reward။
ideal_length: completion အတွက် လိုချင်သော အရှည်။abs(): ဂဏန်းတစ်ခု၏ absolute value (အနုတ်လက္ခဏာမပါသော တန်ဖိုး)။- Verifiable Tasks: အဖြေကို တိကျစွာ စစ်ဆေးအတည်ပြုနိုင်သော လုပ်ငန်းများ။
- Rule-Based Reward Functions: သတ်မှတ်ထားသော စည်းမျဉ်းများ သို့မဟုတ် အခြေအနေများအပေါ် အခြေခံ၍ reward ပေးသော function များ။
extract_final_answer(): completion မှ နောက်ဆုံးအဖြေကို ထုတ်ယူရန် ဒီဇိုင်းထုတ်ထားသော function (ဥပမာတွင် ရိုးရှင်းထားသည်)။- Binary Reward: 0 သို့မဟုတ် 1 ကဲ့သို့သော တန်ဖိုးနှစ်ခုသာ ရှိသော reward (မှန်/မှား)။
- Parsing: စာသားကို ခွဲခြမ်းစိတ်ဖြာပြီး အဓိပ္ပာယ်ဖော်ခြင်း။
- Format-Based Rewards: completion ၏ formatting (ပုံစံချထားမှု) အပေါ်မူတည်၍ ပေးသော reward။
re.search(): string တစ်ခုအတွင်း pattern ကို ရှာဖွေရန် function။re.DOTALL: regular expression flags တစ်ခုဖြစ်ပြီး.(dot) သည် newline character (\n) အပါအဝင် မည်သည့် character ကိုမဆို ကိုက်ညီစေသည်။match.group(1)/match.group(2): regular expression match object မှ သက်ဆိုင်ရာ capture group ၏ contents များကို ထုတ်ယူခြင်း။strip(): string တစ်ခု၏ အစ သို့မဟုတ် အဆုံးရှိ whitespace များကို ဖယ်ရှားခြင်း။- Partial Reward: အပြည့်အဝ reward မဟုတ်ဘဲ တစ်စိတ်တစ်ပိုင်း reward။