""" |
|
|
An naive implementation of split placment example |
|
|
""" |
|
|
|
|
|

import uuid
from copy import deepcopy
from pprint import pprint

import numpy as np
import torch

from verl import DataProto
from verl.trainer.ppo.ray_trainer import (
    AdvantageEstimator,
    apply_kl_penalty,
    compute_advantage,
    compute_data_metrics,
    compute_timing_metrics,
    marked_timer,
)
from verl.utils.metric import reduce_metrics
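

# NOTE (added note, not part of the original example): `fit(self)` is defined at module level so it
# can be bound onto the PPO trainer as a drop-in training loop, e.g. roughly
# `trainer.fit = fit.__get__(trainer)` or by assigning it on the trainer class. The split placement
# setup assumes the actor and critic worker groups live on separate resource pools, which is why
# their update calls below are dispatched first and only waited on afterwards.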
def fit(self):
    """
    The training loop of PPO.

    The driver process only needs to call the compute functions of the worker groups through RPC
    to construct the PPO dataflow. The lightweight advantage computation is done on the driver
    process.
    """
    from omegaconf import OmegaConf

    from verl.utils.tracking import Tracking

    logger = Tracking(
        project_name=self.config.trainer.project_name,
        experiment_name=self.config.trainer.experiment_name,
        default_backend=self.config.trainer.logger,
        config=OmegaConf.to_container(self.config, resolve=True),
    )

    self.global_steps = 0
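
    # load checkpoint before doing anything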
    self._load_checkpoint()
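
    # run validation before training when a validation reward function is configured
    # and `trainer.val_before_train` is enabled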
    if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
        val_metrics = self._validate()
        pprint(f"Initial validation metrics: {val_metrics}")
        logger.log(data=val_metrics, step=self.global_steps)
        if self.config.trainer.get("val_only", False):
            return
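
    # training starts from step 1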
    self.global_steps += 1
    last_val_metrics = None

    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            metrics = {}
            timing_raw = {}

            batch: DataProto = DataProto.from_single_dict(batch_dict)
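
            # pop the keys needed for rollout generation into a separate DataProto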
            gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"])
            is_last_step = self.global_steps >= self.total_training_steps

            with marked_timer("step", timing_raw):
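                # generate a batch of rollouts with the actor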
with marked_timer("gen", timing_raw): |
|
|
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) |
|
|
timing_raw.update(gen_batch_output.meta_info["timing"]) |
|
|
gen_batch_output.meta_info.pop("timing", None) |
|
|
|
|
|
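
                # for REMAX, additionally roll out a greedy (do_sample=False) baseline and keep its
                # reward as `reward_baselines` for advantage estimation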
                if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                    with marked_timer("gen_max", timing_raw):
                        gen_baseline_batch = deepcopy(gen_batch)
                        gen_baseline_batch.meta_info["do_sample"] = False
                        gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                        batch = batch.union(gen_baseline_output)
                        reward_baseline_tensor = self.reward_fn(batch)
                        reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                        batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                        batch.batch["reward_baselines"] = reward_baseline_tensor

                        del gen_baseline_batch, gen_baseline_output
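
                # tag each prompt with a unique id, then repeat the batch to align with the
                # `rollout.n` responses per prompt before merging in the rollout output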
batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) |
|
|
|
|
|
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) |
|
|
batch = batch.union(gen_batch_output) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
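
                # balance the number of valid tokens across data-parallel ranks;
                # note that this reorders the samples inside the batch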
                self._balance_batch(batch, metrics=metrics)

                batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
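
                # recompute the old log-probs under the current actor (used for the PPO ratio)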
with marked_timer("old_log_prob", timing_raw): |
|
|
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) |
|
|
batch = batch.union(old_log_prob) |
|
|
|
|
|
if self.use_reference_policy: |
|
|
|
|
|
with marked_timer("ref", timing_raw): |
|
|
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) |
|
|
batch = batch.union(ref_log_prob) |
|
|
|
|
|
|
|
|
if self.use_critic: |
|
|
with marked_timer("values", timing_raw): |
|
|
values = self.critic_wg.compute_values(batch) |
|
|
batch = batch.union(values) |
|
|
|
|
|
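
                # reward computation and advantage estimation; the advantage part runs on the driver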
with marked_timer("adv", timing_raw): |
|
|
|
|
|
|
|
|
|
|
|
if self.use_rm: |
|
|
|
|
|
reward_tensor = self.rm_wg.compute_rm_score(batch) |
|
|
batch = batch.union(reward_tensor) |
|
|
|
|
|
|
|
|
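
                    # combine the (optional) reward-model score with the rule-based reward function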
                    reward_tensor = self.reward_fn(batch)
                    batch.batch["token_level_scores"] = reward_tensor

                    if self.config.algorithm.use_kl_in_reward:
                        batch, kl_metrics = apply_kl_penalty(
                            batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
                        )
                        metrics.update(kl_metrics)
                    else:
                        batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                    norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
                    batch = compute_advantage(
                        batch,
                        adv_estimator=self.config.algorithm.adv_estimator,
                        gamma=self.config.algorithm.gamma,
                        lam=self.config.algorithm.lam,
                        num_repeat=self.config.actor_rollout_ref.rollout.n,
                        norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                    )
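
                # dispatch the critic update; in the split placement setup this call is assumed to
                # return a future immediately rather than the finished result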
                if self.use_critic:
                    with marked_timer("update_critic_call", timing_raw):
                        critic_output = self.critic_wg.update_critic(batch)
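
                # critic warmup: only dispatch the actor update once `critic_warmup` steps have passed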
                if self.config.trainer.critic_warmup <= self.global_steps:
                    with marked_timer("update_actor_call", timing_raw):
                        actor_output = self.actor_rollout_wg.update_actor(batch)
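
                # NOTE: the block below assumes `update_actor` and `update_critic` are dispatched with
                # blocking=False so both updates run concurrently on their own resource pools; it also
                # assumes the critic is enabled and the warmup condition held, otherwise `critic_output`
                # or `actor_output` would be undefined here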
with marked_timer("update_actor_critic", timing_raw): |
|
|
critic_output = critic_output.get() |
|
|
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) |
|
|
metrics.update(critic_output_metrics) |
|
|
|
|
|
actor_output = actor_output.get() |
|
|
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) |
|
|
metrics.update(actor_output_metrics) |
|
|
|
|
|
|
|
|
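
                # run validation on the last step or at the configured test frequency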
                if (
                    self.val_reward_fn is not None
                    and self.config.trainer.test_freq > 0
                    and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
                ):
                    with marked_timer("testing", timing_raw):
                        val_metrics: dict = self._validate()
                        if is_last_step:
                            last_val_metrics = val_metrics
                    metrics.update(val_metrics)

                if self.config.trainer.save_freq > 0 and (
                    is_last_step or self.global_steps % self.config.trainer.save_freq == 0
                ):
                    with marked_timer("save_checkpoint", timing_raw):
                        self._save_checkpoint()
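
            # collect per-step data and timing metrics, then log them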
            metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
            metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

            logger.log(data=metrics, step=self.global_steps)

            if self.global_steps >= self.total_training_steps:
                pprint(f"Final validation metrics: {last_val_metrics}")
                return

            self.global_steps += 1